"git@developer.sourcefind.cn:norm/vllm.git" did not exist on "1f01a18d39b7fc873b79024b5799597cb6fc88bc"
Commit 7d4ad51e authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

bug fix

parent 37ae6646
...@@ -197,23 +197,23 @@ class ParallelSelfAttention(MegatronModule): ...@@ -197,23 +197,23 @@ class ParallelSelfAttention(MegatronModule):
# Query, Key, and Value # Query, Key, and Value
# ===================== # =====================
# Attention heads [s, b, hp] --> [s, b, 3 * hp] # Attention heads [s, b, hp] --> [s, b, hp * 3]
mixed_x_layer, _ = self.query_key_value(hidden_states) mixed_x_layer, _ = self.query_key_value(hidden_states)
if self.old_checkpoint_format: if self.old_checkpoint_format:
self._transpose_last_dim(mixed_x_layer) # [s, b, 3 * hp] --> [s, b, hp * 3]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
# [s, b, 3 * hp] --> [s, b, np, 3 * hn] # [s, b, hp * 3] --> [s, b, np, hn, 3]
new_tensor_shape = mixed_x_layer.size()[:-1] + \ new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition, (self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head) self.hidden_size_per_attention_head, 3)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [s, b, np, 3 * hn] --> 3 [s, b, np, hn] # [s, b, np, hn, 3] --> 3 [s, b, np, hn]
(query_layer, query_layer = mixed_x_layer[:,:,:,:,0]
key_layer, key_layer = mixed_x_layer[:,:,:,:,1]
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) value_layer = mixed_x_layer[:,:,:,:,2]
# ================================== # ==================================
# Adjust key and value for inference # Adjust key and value for inference
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment