Unverified Commit a62de9ec authored by Antoni Baum, committed by GitHub
Browse files

Fix wrong dtype in PagedAttentionWithALiBi bias (#996)




---------
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
parent 4042d192
......@@ -73,7 +73,12 @@ class PagedAttention(nn.Module):
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
def set_attn_bias(
self,
input_metadata: InputMetadata,
dtype: torch.dtype,
) -> None:
del dtype # Unused.
if input_metadata.attn_bias:
# Already set by a previous layer.
return
......@@ -196,7 +201,7 @@ class PagedAttention(nn.Module):
if num_prompt_tokens > 0:
# Prompt run.
assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata)
self.set_attn_bias(input_metadata, dtype=query.dtype)
self.multi_query_kv_attention(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
......@@ -340,13 +345,14 @@ class PagedAttentionWithALiBi(PagedAttention):
slopes = torch.tensor(slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", slopes, persistent=False)
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
def set_attn_bias(self, input_metadata: InputMetadata,
dtype: torch.dtype) -> None:
if input_metadata.attn_bias:
# Already set by a previous layer.
return
# Generates ALiBi mask for each prompt.
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len)
bias = torch.arange(prompt_len, dtype=dtype)
# Note(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
......@@ -364,6 +370,7 @@ class PagedAttentionWithALiBi(PagedAttention):
prompt_len,
padded_len,
device=self.alibi_slopes.device,
dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None])
attn_bias = LowerTriangularMaskWithTensorBias(bias)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment