Unverified Commit 02e6efe5 authored by r266-tech's avatar r266-tech Committed by GitHub
Browse files

[Bugfix] JAIS: Only apply ALiBi when position_embedding_type='alibi' (#37820)


Co-authored-by: default avatarr266-tech <r266-tech@users.noreply.github.com>
parent 410d3008
......@@ -117,11 +117,14 @@ class JAISAttention(nn.Module):
prefix=f"{prefix}.c_proj",
)
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end]
self.use_alibi = config.position_embedding_type == "alibi"
alibi_slopes = None
if self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end]
self.attn = Attention(
self.num_heads,
self.head_dim,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment