"vscode:/vscode.git/clone" did not exist on "18a26fcfb1983af7fba69db9bdce7ba5e6a9945f"
Unverified Commit 3a9cdc32 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing auto bloom test. (#2699)

parent 78ce618c
......@@ -377,7 +377,7 @@ class BloomAttention(nn.Module):
past_value.view(-1, *past_value.shape[-2:]),
)
if CUSTOM_KERNELS_ENABLED:
if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
assert self.training is False, "Only foward pass was implemented"
assert (
attention_mask.shape[-1] < 4096
......@@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel):
@staticmethod
def _convert_to_bloom_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment