Unverified commit d37f9551 authored by fzyzcjy, committed by GitHub

Improve: Tiny fix Olmo2 (#3348)

parent c66b2c9c
@@ -64,24 +64,24 @@ class Olmo2Attention(nn.Module):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = config.num_attention_heads
         assert self.hidden_size % self.total_num_heads == 0
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = self.config.num_key_value_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
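
The partitioning logic in this hunk is unchanged; the edit only promotes the local tp_size to self.tp_size. A minimal standalone sketch of the KV-head arithmetic it performs (not sglang code; the example tp_size values below are hypothetical):

def partition_kv_heads(total_num_kv_heads: int, tp_size: int) -> int:
    if total_num_kv_heads >= tp_size:
        # More KV heads than TP ranks: split the heads evenly across GPUs.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than TP ranks: each head is replicated on several GPUs.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

# e.g. 8 KV heads: 2 heads per rank at tp_size=4, one replicated head per rank at tp_size=16
assert partition_kv_heads(8, 4) == 2
assert partition_kv_heads(8, 16) == 1
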
@@ -343,7 +343,7 @@ class Olmo2ForCausalLM(nn.Module):
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
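
The second hunk passes the lm_head module itself to the logits processor instead of its raw weight tensor. A standalone sketch (not sglang's LogitsProcessor; all names below are hypothetical) of one reason passing the module can be preferable: the processor can delegate the projection to the head, which still works when the head does not expose a plain .weight (for example a quantized or otherwise wrapped head).

import torch
import torch.nn as nn

class ToyLogitsProcessor(nn.Module):
    # Receives the head module and delegates the hidden-state -> logits projection to it.
    def forward(self, hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
        return lm_head(hidden_states)

lm_head = nn.Linear(16, 32, bias=False)  # toy vocabulary head
logits = ToyLogitsProcessor()(torch.randn(2, 16), lm_head)
assert logits.shape == (2, 32)
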