Unverified Commit 4ade15dd authored by aqweteddy, committed by GitHub

Adjust the reward model's score and pooler module order to reduce computation (#1956)

parent 8dc84da0
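
The reorder matters because the score head only needs the pooled (last-token) hidden state: scoring every token and then pooling throws away almost all of the head's output. Below is a minimal sketch of the saving, not the SGLang code itself; the shapes, the stand-in last-token pooler, and the sizes are assumptions for illustration (the real pooler wraps its result in an EmbeddingPoolerOutput).

import torch
import torch.nn as nn

hidden_size, num_labels, seq_len = 4096, 1, 2048  # hypothetical sizes

score = nn.Linear(hidden_size, num_labels, bias=False)  # classification head
hidden_states = torch.randn(seq_len, hidden_size)       # one sequence's hidden states

def last_token_pool(h: torch.Tensor) -> torch.Tensor:
    # Stand-in for the pooler: keep only the final token's hidden state.
    return h[-1:]

# Old order: run the head on all seq_len tokens, then keep one row.
scores_old = last_token_pool(score(hidden_states))

# New order: keep one row, then run the head once.
scores_new = score(last_token_pool(hidden_states))

# Same result, but the head now processes 1 token instead of seq_len.
assert torch.allclose(scores_old, scores_new)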
@@ -58,43 +58,10 @@ class Gemma2ForSequenceClassification(nn.Module):
         ), "Gemma2ForSequenceClassification is only used for embedding"
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        scores = self.score(hidden_states)
-        return self.pooler(scores, forward_batch)
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.score(last_token_hidden)
+        return EmbeddingPoolerOutput(scores)
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            for param_name, shard_name, shard_id in stacked_params_mapping:
-                if shard_name not in name:
-                    continue
-                name = name.replace(shard_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # lm_head is not used in vllm as it is tied with embed_token.
-                # To prevent errors, skip loading lm_head.weight.
-                if "lm_head.weight" in name:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         Gemma2ForCausalLM.load_weights(self, weights)
...
@@ -59,22 +59,13 @@ class LlamaForSequenceClassification(nn.Module):
         ), "LlamaForSequenceClassification is only used for embedding"
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        scores = self.score(hidden_states)
-        return self.pooler(scores, forward_batch)
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.score(last_token_hidden)
+        return EmbeddingPoolerOutput(scores)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "classification_head" in name:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-            elif "lm_head" in name:
-                continue
-            else:
-                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+        return LlamaForCausalLM.load_weights(self, weights)
 class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
@@ -127,17 +118,7 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         return EmbeddingPoolerOutput(scores)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "classification_head" in name:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-            elif "lm_head" in name:
-                continue
-            else:
-                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+        return super().load_weights(weights)
 EntryClass = [
...