Commit c53104e2 authored by pengcheng888's avatar pengcheng888
Browse files

调整repeat_kv的代码

parent 3ddffe8d
...@@ -29,31 +29,21 @@ logger = logging.get_logger(__name__) ...@@ -29,31 +29,21 @@ logger = logging.get_logger(__name__)
def repeat_kv(keys: infinicore.Tensor, values: infinicore.Tensor, ngroup: int): def repeat_kv(keys: infinicore.Tensor, values: infinicore.Tensor, ngroup: int):
total_seq_len, num_key_value_heads, head_dim = keys.shape total_seq_len, num_heads, head_dim = keys.shape
s0, s1, s2 = keys.stride()
keys_repeat = infinicore.empty( keys_new = (
(total_seq_len, num_key_value_heads, ngroup, head_dim), keys.as_strided((total_seq_len, num_heads, ngroup, head_dim), (s0, s1, 0, s2))
dtype=keys.dtype, .contiguous()
device=keys.device, .view((total_seq_len, num_heads * ngroup, head_dim))
)
values_repeat = infinicore.empty(
(total_seq_len, num_key_value_heads, ngroup, head_dim),
dtype=values.dtype,
device=values.device,
) )
for i in range(ngroup): values_new = (
keys_repeat.narrow(2, i, 1).copy_( values.as_strided((total_seq_len, num_heads, ngroup, head_dim), (s0, s1, 0, s2))
keys.view((total_seq_len, num_key_value_heads, 1, head_dim)) .contiguous()
) .view((total_seq_len, num_heads * ngroup, head_dim))
values_repeat.narrow(2, i, 1).copy_(
values.view((total_seq_len, num_key_value_heads, 1, head_dim))
)
keys_new = keys_repeat.view((total_seq_len, num_key_value_heads * ngroup, head_dim))
values_new = values_repeat.view(
(total_seq_len, num_key_value_heads * ngroup, head_dim)
) )
return keys_new, values_new return keys_new, values_new
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment