sglang · Commits · 4ade15dd

Unverified commit 4ade15dd, authored Nov 08, 2024 by aqweteddy, committed by GitHub on Nov 08, 2024.
Adjust reward model's score module and pooler module order for reducing computation (#1956)
Parent: 8dc84da0

Showing 2 changed files with 8 additions and 60 deletions:
python/sglang/srt/models/gemma2_reward.py (+3, -36)
python/sglang/srt/models/llama_reward.py (+5, -24)
python/sglang/srt/models/gemma2_reward.py
@@ -58,43 +58,10 @@ class Gemma2ForSequenceClassification(nn.Module):
         ), "Gemma2ForSequenceClassification is only used for embedding"
 
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        scores = self.score(hidden_states)
-        return self.pooler(scores, forward_batch)
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            for param_name, shard_name, shard_id in stacked_params_mapping:
-                if shard_name not in name:
-                    continue
-                name = name.replace(shard_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # lm_head is not used in vllm as it is tied with embed_token.
-                # To prevent errors, skip loading lm_head.weight.
-                if "lm_head.weight" in name:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.score(last_token_hidden)
+        return EmbeddingPoolerOutput(scores)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         Gemma2ForCausalLM.load_weights(self, weights)
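The forward-pass change itself is small: instead of running the score head over the hidden states of every token and pooling afterwards, the model now pools first (keeping only the last-token hidden state of each sequence) and applies the score head to that much smaller tensor. The following is a minimal, self-contained sketch of the effect, not sglang code; the tensor sizes, the last_token_idx indices, and the row-wise Linear score head are illustrative assumptions.

# Toy sketch (not sglang code): why pooling before scoring is cheaper.
# Assumed setup: hidden states for all tokens in the batch are stacked into one
# [num_tokens, hidden_size] tensor, the pooler keeps the last token of each
# sequence, and the score head is a Linear layer applied row-wise.
import torch

hidden_size, num_labels = 4096, 1
num_tokens = 8192                                         # all prompt tokens in the batch
last_token_idx = torch.tensor([2047, 4095, 6143, 8191])   # one per sequence (hypothetical)

score = torch.nn.Linear(hidden_size, num_labels)
hidden_states = torch.randn(num_tokens, hidden_size)

# Old order: score every token (a num_tokens x hidden_size matmul), then pool.
scores_old = score(hidden_states)[last_token_idx]

# New order: pool first, then score only len(last_token_idx) rows.
scores_new = score(hidden_states[last_token_idx])

torch.testing.assert_close(scores_old, scores_new)  # same scores, far fewer rows scored

In this sketch both orderings give identical scores because the score head acts row by row, so selecting the last-token rows before the matmul only removes work that would have been discarded anyway.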
python/sglang/srt/models/llama_reward.py
@@ -59,22 +59,13 @@ class LlamaForSequenceClassification(nn.Module):
         ), "LlamaForSequenceClassification is only used for embedding"
 
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        scores = self.score(hidden_states)
-        return self.pooler(scores, forward_batch)
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.score(last_token_hidden)
+        return EmbeddingPoolerOutput(scores)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in weights:
-            if "classification_head" in name:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-            elif "lm_head" in name:
-                continue
-            else:
-                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+        return LlamaForCausalLM.load_weights(self, weights)
 
 
 class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):

@@ -127,17 +118,7 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         return EmbeddingPoolerOutput(scores)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in weights:
-            if "classification_head" in name:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-            elif "lm_head" in name:
-                continue
-            else:
-                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+        return super().load_weights(weights)
 
 
 EntryClass = [
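A secondary cleanup in both files: the reward models stop re-implementing weight loading name by name and instead delegate to the corresponding causal-LM loader (Gemma2ForCausalLM.load_weights, LlamaForCausalLM.load_weights, or super() in the subclass). Below is a toy sketch of that delegation pattern only; the class names ToyCausalLM and ToyRewardModel are made up and stand in for the real models.

# Toy sketch (not sglang code) of delegating weight loading to the base model's
# loader instead of duplicating the per-name dispatch loop in every subclass.
from typing import Iterable, Tuple
import torch

class ToyCausalLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 4)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params = dict(self.named_parameters())
        for name, tensor in weights:
            if name in params:          # ignore weights this module does not own
                params[name].data.copy_(tensor)

class ToyRewardModel(ToyCausalLM):
    def __init__(self):
        super().__init__()
        self.score = torch.nn.Linear(4, 1)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Reuse the base loader, as the diff does with
        # LlamaForCausalLM.load_weights / Gemma2ForCausalLM.load_weights.
        return ToyCausalLM.load_weights(self, weights)

model = ToyRewardModel()
model.load_weights([("embed.weight", torch.zeros(10, 4))])
print(model.embed.weight.abs().sum())   # tensor(0.)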