Unverified Commit d37f9551 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Improve: Tiny fix Olmo2 (#3348)

parent c66b2c9c
......@@ -64,24 +64,24 @@ class Olmo2Attention(nn.Module):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.hidden_size % self.total_num_heads == 0
assert self.total_num_heads % tp_size == 0
assert self.total_num_heads % self.tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.num_heads = self.total_num_heads // self.tp_size
self.total_num_kv_heads = self.config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
if self.total_num_kv_heads >= self.tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
assert self.total_num_kv_heads % self.tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
assert self.tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
......@@ -343,7 +343,7 @@ class Olmo2ForCausalLM(nn.Module):
input_embeds=input_embeds,
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment