Commit d37f9551 (Unverified)
Authored Feb 22, 2025 by fzyzcjy, committed by GitHub on Feb 21, 2025

Improve: Tiny fix Olmo2 (#3348)

Parent: c66b2c9c
Showing 1 changed file with 8 additions and 8 deletions.

python/sglang/srt/models/olmo2.py  (+8 / -8)
...
@@ -64,24 +64,24 @@ class Olmo2Attention(nn.Module):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = config.num_attention_heads
         assert self.hidden_size % self.total_num_heads == 0
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = self.config.num_key_value_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
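For context on what this hunk computes (the change itself only promotes the local tp_size to an attribute): each tensor-parallel rank gets total_num_heads // tp_size query heads, and KV heads are either partitioned across ranks (when total_num_kv_heads >= tp_size) or replicated (when tp_size is a multiple of total_num_kv_heads). A minimal standalone sketch of that arithmetic follows; the helper per_rank_head_counts is hypothetical and not part of sglang, though the variable names mirror the diff.

# Illustrative sketch only, assuming the same partition/replicate rules as
# the diff above; per_rank_head_counts is not an sglang function.
def per_rank_head_counts(total_num_heads: int, total_num_kv_heads: int, tp_size: int):
    assert total_num_heads % tp_size == 0
    if total_num_kv_heads >= tp_size:
        # Partition KV heads across tensor-parallel ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Replicate KV heads: several ranks share the same KV head.
        assert tp_size % total_num_kv_heads == 0
    num_heads = total_num_heads // tp_size
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    return num_heads, num_kv_heads

# 32 query heads, 8 KV heads, tensor parallelism over 4 GPUs
# -> 8 query heads and 2 KV heads per rank.
print(per_rank_head_counts(32, 8, 4))   # (8, 2)
# 8 KV heads replicated over 16 ranks -> 1 KV head per rank.
print(per_rank_head_counts(32, 8, 16))  # (2, 1)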
...
@@ -343,7 +343,7 @@ class Olmo2ForCausalLM(nn.Module):
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
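The second hunk hands the whole lm_head module to the logits processor instead of its raw weight tensor. As a hedged illustration of why a processor might prefer the module (the ToyLogitsProcessor below is hypothetical and not sglang's LogitsProcessor), a head object can encapsulate its own projection logic, whereas a bare weight tensor forces the caller to assume a plain matmul:

import torch
import torch.nn as nn

class ToyLogitsProcessor(nn.Module):
    # Hypothetical stand-in: given hidden states and a head *module*, defer to
    # whatever projection the head implements (quantized, tensor-parallel,
    # tied embeddings, ...) instead of assuming `hidden @ weight.T`.
    def forward(self, hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
        return lm_head(hidden_states)

hidden = torch.randn(2, 16)               # (batch, hidden_size)
lm_head = nn.Linear(16, 100, bias=False)  # hidden_size -> vocab_size
logits = ToyLogitsProcessor()(hidden, lm_head)
print(logits.shape)                       # torch.Size([2, 100])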