Unverified commit 93726b2a, authored by lalit10 and committed by GitHub
Browse files

Refactor Arctic loading to use AutoWeightsLoader (#38955)


Signed-off-by: Lalit Laxminarayan Bangad <lalitbangad@gmail.com>
Co-authored-by: Lalit Laxminarayan Bangad <lalitbangad@meta.com>
parent 8617f867
...@@ -16,7 +16,6 @@ from vllm.distributed import ( ...@@ -16,7 +16,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
...@@ -42,6 +41,7 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig ...@@ -42,6 +41,7 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig
from .interfaces import SupportsPP, SupportsQuant from .interfaces import SupportsPP, SupportsQuant
from .utils import ( from .utils import (
AutoWeightsLoader,
extract_layer_index, extract_layer_index,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_empty_intermediate_tensors_factory,
...@@ -49,8 +49,6 @@ from .utils import ( ...@@ -49,8 +49,6 @@ from .utils import (
maybe_prefix, maybe_prefix,
) )
logger = init_logger(__name__)
class ArcticMLP(nn.Module): class ArcticMLP(nn.Module):
def __init__( def __init__(
...@@ -384,6 +382,7 @@ class ArcticModel(nn.Module): ...@@ -384,6 +382,7 @@ class ArcticModel(nn.Module):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size
...@@ -426,57 +425,6 @@ class ArcticModel(nn.Module): ...@@ -426,57 +425,6 @@ class ArcticModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.model = ArcticModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
...@@ -487,41 +435,26 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -487,41 +435,26 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
mlp_params_mapping: list[tuple[str, str, int]] = [] mlp_params_mapping: list[tuple[str, str, int]] = []
expert_params_mapping: list[tuple[str, str, int]] = [] expert_params_mapping: list[tuple[str, str, int]] = []
num_layers = self.config.num_hidden_layers
for layer in range(self.config.num_hidden_layers):
for layer in range(num_layers): is_moe_layer = (layer + 1) % self.config.moe_layer_frequency == 0
mlp_params_mapping.append( if is_moe_layer and self.config.use_residual:
(
f"layers.{layer}.residual_mlp.w13.weight",
f"layers.{layer}.residual_mlp.w1.weight",
0,
)
)
mlp_params_mapping.append(
(
f"layers.{layer}.residual_mlp.w13.weight",
f"layers.{layer}.residual_mlp.w3.weight",
1,
)
)
if layer % 2 == 0:
# MLP layers
mlp_params_mapping.append( mlp_params_mapping.append(
( (
f"layers.{layer}.block_sparse_moe.mlp.w13.weight", f"layers.{layer}.residual_mlp.w13.weight",
f"layers.{layer}.block_sparse_moe.mlp.w1.weight", f"layers.{layer}.residual_mlp.w1.weight",
0, 0,
) )
) )
mlp_params_mapping.append( mlp_params_mapping.append(
( (
f"layers.{layer}.block_sparse_moe.mlp.w13.weight", f"layers.{layer}.residual_mlp.w13.weight",
f"layers.{layer}.block_sparse_moe.mlp.w3.weight", f"layers.{layer}.residual_mlp.w3.weight",
1, 1,
) )
) )
else:
# MoE layers if is_moe_layer:
for expert_id in range(self.config.num_local_experts): for expert_id in range(self.config.num_local_experts):
expert_params_mapping.append( expert_params_mapping.append(
("ws", f"experts.{expert_id}.w1.weight", expert_id) ("ws", f"experts.{expert_id}.w1.weight", expert_id)
...@@ -532,15 +465,25 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -532,15 +465,25 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
expert_params_mapping.append( expert_params_mapping.append(
("ws", f"experts.{expert_id}.w3.weight", expert_id) ("ws", f"experts.{expert_id}.w3.weight", expert_id)
) )
else:
mlp_params_mapping.append(
(
f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
f"layers.{layer}.block_sparse_moe.mlp.w1.weight",
0,
)
)
mlp_params_mapping.append(
(
f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
f"layers.{layer}.block_sparse_moe.mlp.w3.weight",
1,
)
)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
logger.info(
"It will take ~10 minutes loading from the 16-bit weights. "
"Alternatively, use the prequantized 8-bit weights of arctic "
"and set load-format to `sharded_state` will accelerate loading."
)
for name, loaded_weight in weights: for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -585,10 +528,67 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -585,10 +528,67 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr( weight_loader = getattr(
param, "weight_loader", default_weight_loader param, "weight_loader", default_weight_loader
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
    """Causal-LM wrapper: an ``ArcticModel`` backbone plus a ``ParallelLMHead``.

    The backbone produces hidden states in ``forward``; logits are computed
    separately in ``compute_logits`` through a ``LogitsProcessor``.
    """

    # Fused projection name -> the individual projection names it packs.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Engine configuration; the HF model config and the
                quantization config are read from it.
            prefix: Name prefix handed to ``maybe_prefix`` for the ``model``
                and ``lm_head`` submodules.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config
        self.config = hf_config

        # Submodule creation order (backbone first, then head) is kept as-is.
        self.model = ArcticModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.vocab_size = hf_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            hf_config.hidden_size,
            quant_config=quantization,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Tied embeddings: the head reuses the input embedding weight.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.num_experts = hf_config.num_local_experts
        self.num_experts_per_tok = hf_config.num_experts_per_tok
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Re-export the backbone's factory under the same attribute name.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-id embedding to the backbone."""
        embeddings = self.model.embed_input_ids(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return whatever it returns, unchanged.

        All four arguments are forwarded positionally to ``self.model``.
        """
        return self.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Apply the logits processor to ``hidden_states`` using the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs through ``AutoWeightsLoader``.

        When the config ties word embeddings, the ``"lm_head."`` prefix is
        passed as ``skip_prefixes``; otherwise no prefixes are skipped.

        Returns:
            The result of ``AutoWeightsLoader.load_weights``.
        """
        if self.config.tie_word_embeddings:
            prefixes_to_skip = ["lm_head."]
        else:
            prefixes_to_skip = None
        weights_loader = AutoWeightsLoader(self, skip_prefixes=prefixes_to_skip)
        return weights_loader.load_weights(weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment