Unverified Commit 05c094fc authored by Daniël de Kok, committed by GitHub

Consistently take `prefix` in model constructors (#2191)

* Consistently take `prefix` in model constructors

* Release test check fix

* Misc refactor-related fixes
parent 67ef0649
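
In sketch form, the convention the diff below applies to every model constructor is the following. This is a minimal illustration only: `ToyModel` and `ToyForCausalLM` are hypothetical stand-ins, not classes from the repository; the real constructors appear in the hunks below.

class ToyModel:
    def __init__(self, prefix: str, config, weights):
        # Weight names are derived from the caller-supplied prefix instead of
        # being hard-coded, e.g. f"{prefix}.embed_tokens", f"{prefix}.norm".
        self.embed_tokens_name = f"{prefix}.embed_tokens"

class ToyForCausalLM:
    def __init__(self, prefix: str, config, weights):
        # An empty prefix keeps the standalone checkpoint layout ("model.*");
        # a non-empty prefix nests the weights under "<prefix>.model.*",
        # e.g. when the language model is embedded inside a larger model.
        if not prefix:
            prefix = "model"
        else:
            prefix = f"{prefix}.model"
        self.model = ToyModel(prefix, config, weights)

The `CausalLM` wrapper now passes an empty prefix (`prefix = ""`) when constructing a model directly, as the corresponding hunk below shows.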
@@ -153,7 +153,7 @@ jobs:
     runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
     if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
     env:
-      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == 'true') && '--release' || '' }}
+      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
...
@@ -16,6 +16,7 @@ from text_generation_server.models.custom_modeling.opt_modeling import OPTForCau
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
 )
+from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
@@ -522,7 +523,7 @@ def get_model(
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
-            batch_class=CausalLMBatchKeysLast,
+            batch_class=BloomCausalLMBatch,
         )
     elif model_type == MPT:
         return CausalLM(
...
@@ -553,7 +553,8 @@ class CausalLM(Model):
         if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
-        model = model_class(config, weights)
+        prefix = ""
+        model = model_class(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super().__init__(
...
@@ -816,7 +816,7 @@ class BloomModel(BloomPreTrainedModel):
 class BloomForCausalLM(BloomPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
         self.transformer = BloomModel(config, weights)
...
@@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module):
 class CLIPTextTransformer(nn.Module):
-    def __init__(self, config: CLIPTextConfig):
+    def __init__(self, prefix: str, config: CLIPTextConfig):
         super().__init__()
         self.config = config
         embed_dim = config.hidden_size
@@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel):
     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
-    def __init__(self, config: CLIPTextConfig):
+    def __init__(self, prefix, config: CLIPTextConfig):
         super().__init__(config)
-        self.text_model = CLIPTextTransformer(config)
+        self.text_model = CLIPTextTransformer(prefix, config)
         # Initialize weights and apply final processing
         self.post_init()
...
@@ -363,9 +363,9 @@ class CohereMLP(nn.Module):
 class FlashCohereLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = FlashCohereAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -416,18 +416,19 @@ class FlashCohereLayer(nn.Module):
 class FlashCohereModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=f"{prefix}.embed_tokens", weights=weights
         )
         self.layers = nn.ModuleList(
             [
                 FlashCohereLayer(
+                    prefix,
                     layer_id,
                     config,
                     weights,
@@ -436,7 +437,7 @@ class FlashCohereModel(torch.nn.Module):
             ]
         )
         self.norm = FastLayerNorm.load_no_bias(
-            prefix="model.norm", weights=weights, eps=config.layer_norm_eps
+            prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
         )
         self.gradient_checkpointing = False
@@ -486,10 +487,15 @@ class FlashCohereModel(torch.nn.Module):
 class FlashCohereForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
-        self.model = FlashCohereModel(config, weights)
+        if not prefix:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+        self.model = FlashCohereModel(prefix, config, weights)
         try:
             self.lm_head = SpeculativeHead.load(
                 config,
@@ -499,7 +505,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
         except RuntimeError:
             self.lm_head = SpeculativeHead.load(
                 config,
-                prefix="model.embed_tokens",
+                prefix=f"{prefix}.embed_tokens",
                 weights=weights,
             )
         self.logit_scale = config.logit_scale
...
@@ -593,9 +593,9 @@ class DenseMoE(nn.Module):
 class DbrxLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
-        prefix = f"transformer.blocks.{layer_id}"
+        prefix = f"{prefix}.blocks.{layer_id}"
         self.attn = DbrxNormAttentionNorm(
             prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
@@ -637,16 +637,17 @@ class DbrxLayer(nn.Module):
 class DbrxModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="transformer.wte", weights=weights
+            prefix=f"{prefix}.wte", weights=weights
         )
         self.layers = nn.ModuleList(
             [
                 DbrxLayer(
+                    prefix,
                     layer_id,
                     config,
                     weights,
@@ -655,7 +656,7 @@ class DbrxModel(torch.nn.Module):
             ]
         )
         self.norm = FastLayerNorm.load_no_bias(
-            prefix="transformer.norm_f", weights=weights, eps=1e-5
+            prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
         )
         self.head_size = self.layers[0].attn.self_attn.head_size
@@ -702,9 +703,14 @@ class DbrxModel(torch.nn.Module):
 class FlashDbrxForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
         self.model = DbrxModel(config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
...
@@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig):
 class Gemma2FastRMSNorm(FastRMSNorm):
     @classmethod
-    def load(cls, prefix, weights, eps=1e-6):
+    def load(cls, prefix: str, weights, eps=1e-6):
         dtype = weights.dtype
         weights.dtype = torch.float32
         weight = weights.get_tensor(f"{prefix}.weight") + 1
@@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm):
         return hidden_states.to(self.dtype), residual
-def load_attention(config, prefix, weights):
+def load_attention(config, prefix: str, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
@@ -305,7 +305,7 @@ class Gemma2MLP(nn.Module):
 class FlashGemma2Layer(nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
+    def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
         super().__init__()
         self.self_attn = FlashGemma2Attention(
             prefix=f"{prefix}.self_attn",
@@ -376,7 +376,7 @@ class FlashGemma2Layer(nn.Module):
 class FlashGemma2Model(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()
         process_group = weights.process_group
@@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):
 class FlashGemma2ForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, *, causal: bool = True):
+    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
         super().__init__()
         embed_norm = config.hidden_size**0.5
...
@@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig):
 class GemmaFastRMSNorm(FastRMSNorm):
     @classmethod
-    def load(cls, prefix, weights, eps=1e-6):
+    def load(cls, prefix: str, weights, eps=1e-6):
         dtype = weights.dtype
         weights.dtype = torch.float32
         weight = weights.get_tensor(f"{prefix}.weight") + 1
@@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
         return hidden_states.to(self.dtype), residual
-def load_attention(config, prefix, weights):
+def load_attention(config, prefix: str, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
@@ -261,7 +261,7 @@ class FlashGemmaAttention(torch.nn.Module):
 class GemmaMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -299,7 +299,7 @@ class GemmaMLP(nn.Module):
 class FlashGemmaLayer(nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()
         self.self_attn = FlashGemmaAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
@@ -354,7 +354,7 @@ class FlashGemmaLayer(nn.Module):
 class FlashGemmaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()
         process_group = weights.process_group
@@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):
 class FlashGemmaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, *, causal: bool = True):
+    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
         super().__init__()
         embed_norm = config.hidden_size**0.5
...
@@ -261,7 +261,7 @@ class FlashGPT2Attention(torch.nn.Module):
 class GPT2MLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         act = config.activation_function
         self.act = (
@@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):
 class FlashGPT2Layer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.self_attn = FlashGPT2Attention(
             prefix=f"{prefix}.attn", config=config, weights=weights
@@ -350,7 +350,7 @@ class FlashGPT2Layer(nn.Module):
 class FlashGPT2Model(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         process_group = weights.process_group
@@ -414,7 +414,7 @@ class FlashGPT2Model(torch.nn.Module):
 class FlashGPT2ForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.embed_tokens = TensorParallelEmbedding(
...
@@ -54,7 +54,7 @@ if SYSTEM == "rocm":
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
-def load_attention(config, prefix, weights, layer_id):
+def load_attention(config, prefix: str, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
     head_size = config.hidden_size // config.num_attention_heads
@@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module):
 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.embed_tokens = TensorParallelEmbedding(
...
@@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module):
 class MistralMLP(nn.Module):
-    def __init__(self, prefix, config, weights, layer_id):
+    def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
         self.hidden_act = config.hidden_act
         self.act = (
@@ -328,7 +328,7 @@ class MistralMLP(nn.Module):
 class MistralLayer(nn.Module):
-    def __init__(self, prefix, config, weights, layer_id):
+    def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
         self.self_attn = MistralAttention(
             prefix=f"{prefix}.self_attn",
@@ -392,7 +392,7 @@ class MistralLayer(nn.Module):
 class MistralModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         process_group = weights.process_group
@@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module):
 class FlashMistralForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, name=None):
+    def __init__(self, prefix: str, config, weights, name=None):
         if name is None:
             name = "model"
         super().__init__()
...
@@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
     return x.view(1) if len(x.size()) == 0 else x
-def load_attention(config, prefix, weights):
+def load_attention(config, prefix: str, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
@@ -155,7 +155,7 @@ def _load_gqa(config, prefix: str, weights):
     )
-def _load_experts(config, prefix, mat, weights):
+def _load_experts(config, prefix: str, mat, weights):
     if config.quantize is not None:
         raise NotImplementedError("Mixtral does not support weight quantization yet.")
@@ -475,7 +475,7 @@ class DenseMoE(nn.Module):
 class MixtralLayer(nn.Module):
-    def __init__(self, prefix, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
         prefix = f"{prefix}.layers.{layer_id}"
@@ -536,7 +536,7 @@ class MixtralLayer(nn.Module):
 class MixtralModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.embed_tokens = TensorParallelEmbedding(
@@ -610,7 +610,7 @@ class MixtralModel(torch.nn.Module):
 class FlashMixtralForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.model = MixtralModel(prefix, config, weights)
...
@@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
 class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
         self.config = config
         self.embed_in = TensorParallelEmbedding(
-            prefix="gpt_neox.embed_in", weights=weights
+            prefix=f"{prefix}.embed_in", weights=weights
         )
         self.layers = nn.ModuleList(
@@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             ]
         )
         self.final_layer_norm = FastLayerNorm.load(
-            prefix="gpt_neox.final_layer_norm",
+            prefix=f"{prefix}.final_layer_norm",
             weights=weights,
             eps=config.layer_norm_eps,
         )
@@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__(config)
-        self.gpt_neox = FlashGPTNeoXModel(config, weights)
+        if not prefix:
+            prefix = "gpt_neox"
+        else:
+            prefix = f"{prefix}.gpt_neox"
+        self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)
         self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
...
@@ -258,9 +258,9 @@ class PhiMLP(nn.Module):
 class FlashPhiLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = FlashPhiAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -307,18 +307,19 @@ class FlashPhiLayer(nn.Module):
 class FlashPhiModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=f"{prefix}.embed_tokens", weights=weights
         )
         self.layers = nn.ModuleList(
             [
                 FlashPhiLayer(
+                    prefix,
                     layer_id,
                     config,
                     weights,
@@ -378,10 +379,15 @@ class FlashPhiModel(torch.nn.Module):
 class FlashPhiForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
-        self.model = FlashPhiModel(config, weights)
+        if not prefix:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+        self.model = FlashPhiModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
...
@@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module):
 class Qwen2Layer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = Qwen2Attention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module):
 class Qwen2Model(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=f"{prefix}.embed_tokens", weights=weights
         )
         self.layers = nn.ModuleList(
             [
                 Qwen2Layer(
+                    prefix,
                     layer_id,
                     config,
                     weights,
@@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module):
             ]
         )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
         )
         self.gradient_checkpointing = False
@@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module):
 class Qwen2ForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
-        self.model = Qwen2Model(config, weights)
+        if not prefix:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+        self.model = Qwen2Model(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
...
@@ -127,7 +127,7 @@ class FlashRWAttention(torch.nn.Module):
     def __init__(
         self,
         config,
-        prefix,
+        prefix: str,
         weights,
     ):
         super().__init__()
@@ -236,7 +236,7 @@ class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
         config,
-        prefix,
+        prefix: str,
         weights,
     ):
         super().__init__()
@@ -358,7 +358,7 @@ class FlashRWLargeAttention(torch.nn.Module):
 class FlashMLP(nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, prefix: str, weights):
         super().__init__()
         self.act = torch.nn.functional.gelu
@@ -380,6 +380,7 @@ class FlashRWLayer(nn.Module):
     def __init__(
         self,
         layer_id,
+        prefix: str,
         config,
         weights,
     ):
@@ -388,7 +389,7 @@ class FlashRWLayer(nn.Module):
         parallel_attn = config.parallel_attn
         self.parallel_attn = parallel_attn
-        prefix = f"transformer.h.{layer_id}"
+        prefix = f"{prefix}.h.{layer_id}"
         self.input_layernorm = FastLayerNorm.load(
             prefix=f"{prefix}.input_layernorm",
@@ -479,7 +480,7 @@ class FlashRWLayer(nn.Module):
 class FlashRWLayerNorm(nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, prefix: str, weights):
         super().__init__()
         self.num_ln = config.num_ln_in_parallel_attn
@@ -518,9 +519,9 @@ class FlashRWLayerNorm(nn.Module):
 class FlashRWLargeLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, layer_id, prefix: str, config, weights):
         super().__init__()
-        prefix = f"transformer.h.{layer_id}"
+        prefix = f"{prefix}.h.{layer_id}"
         self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
@@ -580,18 +581,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
 class FlashRWModel(FlashRWPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
         self.config = config
         self.word_embeddings = TensorParallelEmbedding(
-            prefix="transformer.word_embeddings", weights=weights
+            prefix=f"{prefix}.word_embeddings", weights=weights
         )
         if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(layer_id, config, weights)
+                    FlashRWLargeLayer(layer_id, prefix, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
@@ -599,14 +600,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
         else:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLayer(layer_id, config, weights)
+                    FlashRWLayer(layer_id, prefix, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
             self.cache_size = self.h[0].self_attention.num_heads_kv
         self.ln_f = FastLayerNorm.load(
-            prefix="transformer.ln_f",
+            prefix=f"{prefix}.ln_f",
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
@@ -653,10 +654,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
 class FlashRWForCausalLM(FlashRWPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
-        self.transformer = FlashRWModel(config, weights)
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
+        self.transformer = FlashRWModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
...
@@ -346,9 +346,9 @@ class MLP(nn.Module):
 class Block(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
-        prefix = f"transformer.h.{layer_id}"
+        prefix = f"{prefix}.h.{layer_id}"
         self.ln_1 = FastLayerNorm.load(
             prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
         )
@@ -396,18 +396,18 @@ class Block(nn.Module):
 class FlashSantacoderModel(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.config = config
         self.process_group = weights.process_group
         self.wte = TensorParallelEmbedding(
-            prefix="transformer.wte",
+            prefix=f"{prefix}.wte",
             weights=weights,
             reduce=False,
         )
         self.wpe = TensorParallelEmbedding(
-            prefix="transformer.wpe",
+            prefix=f"{prefix}.wpe",
             weights=weights,
             reduce=False,
         )
@@ -415,6 +415,7 @@ class FlashSantacoderModel(nn.Module):
         self.layers = nn.ModuleList(
             [
                 Block(
+                    prefix,
                     layer_id,
                     config,
                     weights,
@@ -466,10 +467,16 @@ class FlashSantacoderModel(nn.Module):
 class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
         config.transpose = config.architectures[0].startswith("GPT2")
-        self.model = FlashSantacoderModel(config, weights)
+        self.model = FlashSantacoderModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
-            config, prefix="transformer.wte", weights=weights
+            config, prefix=f"{prefix}.wte", weights=weights
         )
     def forward(
...
@@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel):
 class MPTModel(MPTPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         # config._validate_config()
         super().__init__(config)
         self.world_size = weights.process_group.size()
@@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel):
                 f"Requested norm type ({config.norm_type}) is not implemented within this repo."
             )
-        self.wte = TensorParallelEmbedding("transformer.wte", weights)
+        self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)
         if not self.alibi:
-            self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
+            self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
         self.blocks = nn.ModuleList(
             [
-                MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
+                MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
                 for i in range(config.n_layers)
             ]
         )
@@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel):
 class MPTForCausalLM(MPTPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
         if not config.tie_word_embeddings:
             raise ValueError("MPTForCausalLM only supports tied word embeddings")
-        self.transformer = MPTModel(config, weights)
+        self.transformer = MPTModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
-            config, prefix="transformer.wte", weights=weights
+            config, prefix=f"{prefix}.wte", weights=weights
         )
         self.logit_scale = None
         if config.logit_scale is not None:
...
@@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module):
 class GPTNeoXLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, layer_id, prefix: str, config, weights):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm.load(
-            prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
+            prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
         )
         self.post_attention_layernorm = nn.LayerNorm.load(
-            prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
+            prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
         )
         self.attention = GPTNeoXAttention(
-            config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
+            config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
         )
         self.mlp = GPTNeoXMLP(
-            config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
+            config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
         )
     def forward(
@@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module):
 class GPTNeoXModel(GPTNeoXPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
         self.config = config
         self.num_attention_heads = config.num_attention_heads
         self.embed_in = TensorParallelEmbedding(
-            prefix="gpt_neox.embed_in", weights=weights
+            prefix=f"{prefix}.embed_in", weights=weights
         )
         self.layers = nn.ModuleList(
             [
-                GPTNeoXLayer(layer_id, config, weights)
+                GPTNeoXLayer(layer_id, prefix, config, weights)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.final_layer_norm = nn.LayerNorm.load(
-            prefix="gpt_neox.final_layer_norm",
+            prefix=f"{prefix}.final_layer_norm",
             weights=weights,
             eps=config.layer_norm_eps,
         )
@@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__(config)
-        self.gpt_neox = GPTNeoXModel(config, weights)
+        if not prefix:
+            prefix = "gpt_neox"
+        else:
+            prefix = f"{prefix}.gpt_neox"
+        self.gpt_neox = GPTNeoXModel(prefix, config, weights)
         self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )
...