Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19f0d257
Unverified
Commit
19f0d257
authored
Oct 02, 2024
by
Shawn Tan
Committed by
GitHub
Oct 03, 2024
Browse files
[Model] Adding Granite MoE. (#8206)
Co-authored-by:
Nick Hill
<
nickhill@us.ibm.com
>
parent
f58d4fcc
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
492 additions
and
3 deletions
+492
-3
tests/models/decoder_only/language/test_granitemoe.py
tests/models/decoder_only/language/test_granitemoe.py
+39
-0
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/granite.py
vllm/model_executor/models/granite.py
+4
-3
vllm/model_executor/models/granitemoe.py
vllm/model_executor/models/granitemoe.py
+448
-0
No files found.
tests/models/decoder_only/language/test_granitemoe.py
0 → 100644
View file @
19f0d257
"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
Run `pytest tests/models/test_granite.py`.
"""
import
pytest
from
...utils
import
check_logprobs_close
MODELS
=
[
"ibm/PowerMoE-3b"
,
]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
)
->
None
:
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
vllm/model_executor/models/__init__.py
View file @
19f0d257
...
@@ -32,6 +32,7 @@ _GENERATION_MODELS = {
...
@@ -32,6 +32,7 @@ _GENERATION_MODELS = {
"GPTJForCausalLM"
:
(
"gpt_j"
,
"GPTJForCausalLM"
),
"GPTJForCausalLM"
:
(
"gpt_j"
,
"GPTJForCausalLM"
),
"GPTNeoXForCausalLM"
:
(
"gpt_neox"
,
"GPTNeoXForCausalLM"
),
"GPTNeoXForCausalLM"
:
(
"gpt_neox"
,
"GPTNeoXForCausalLM"
),
"GraniteForCausalLM"
:
(
"granite"
,
"GraniteForCausalLM"
),
"GraniteForCausalLM"
:
(
"granite"
,
"GraniteForCausalLM"
),
"GraniteMoeForCausalLM"
:
(
"granitemoe"
,
"GraniteMoeForCausalLM"
),
"InternLMForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"InternLMForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
"JAISLMHeadModel"
:
(
"jais"
,
"JAISLMHeadModel"
),
"JAISLMHeadModel"
:
(
"jais"
,
"JAISLMHeadModel"
),
...
...
vllm/model_executor/models/granite.py
View file @
19f0d257
...
@@ -404,9 +404,12 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
...
@@ -404,9 +404,12 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
if
hasattr
(
config
,
"logits_scaling"
):
logit_scale
/=
config
.
logits_scaling
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
config
.
vocab_size
,
logit_scale
)
scale
=
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
else
:
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
lm_head
=
PPMissingLayer
()
...
@@ -428,8 +431,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
...
@@ -428,8 +431,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata
:
SamplingMetadata
)
->
Optional
[
torch
.
Tensor
]:
sampling_metadata
:
SamplingMetadata
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
if
logits
is
not
None
:
logits
/=
self
.
config
.
logits_scaling
return
logits
return
logits
def
sample
(
def
sample
(
...
...
vllm/model_executor/models/granitemoe.py
0 → 100644
View file @
19f0d257
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GraniteMoe model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers.models.granitemoe
import
GraniteMoeConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.
import
mixtral
from
.interfaces
import
SupportsLoRA
from
.utils
import
make_layers
class
GraniteMoeMoE
(
nn
.
Module
):
"""A tensor-parallel MoE implementation for GraniteMoe that shards each
expert across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def
__init__
(
self
,
num_experts
:
int
,
top_k
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
# Gate always runs at half / full precision for now.
self
.
gate
=
ReplicatedLinear
(
hidden_size
,
num_experts
,
bias
=
False
,
params_dtype
=
params_dtype
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
)
self
.
experts
=
FusedMoE
(
num_experts
=
num_experts
,
top_k
=
top_k
,
hidden_size
=
hidden_size
,
intermediate_size
=
intermediate_size
,
params_dtype
=
params_dtype
,
reduce_results
=
True
,
renormalize
=
True
,
quant_config
=
quant_config
,
tp_size
=
tp_size
,
prefix
=
f
"
{
prefix
}
.experts"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_size
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
,
router_logits
)
return
final_hidden_states
.
view
(
orig_shape
)
class
GraniteMoeAttention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
max_position
:
int
=
4096
*
32
,
rope_theta
:
float
=
10000
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
attention_multiplier
:
Optional
[
float
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
(
attention_multiplier
if
attention_multiplier
is
not
None
else
self
.
head_dim
**-
1
)
self
.
rope_theta
=
rope_theta
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position
,
base
=
int
(
self
.
rope_theta
),
is_neox_style
=
True
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
GraniteMoeDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GraniteMoeConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
# Requires transformers > 4.32.0
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
self
.
self_attn
=
GraniteMoeAttention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
max_position
=
config
.
max_position_embeddings
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
attention_multiplier
=
config
.
attention_multiplier
)
self
.
block_sparse_moe
=
GraniteMoeMoE
(
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.block_sparse_moe"
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
residual_multiplier
=
config
.
residual_multiplier
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# Self Attention
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
residual
+
hidden_states
*
self
.
residual_multiplier
residual
=
hidden_states
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
block_sparse_moe
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
*
self
.
residual_multiplier
return
hidden_states
class
GraniteMoeModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GraniteMoeConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
org_vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
embedding_multiplier
=
config
.
embedding_multiplier
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
GraniteMoeDecoderLayer
(
config
,
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
torch
.
Tensor
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
*=
self
.
embedding_multiplier
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class
GraniteMoeForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"embed_tokens"
,
"lm_head"
,
]
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
,
}
embedding_padding_modules
=
[
"lm_head"
]
def
__init__
(
self
,
config
:
GraniteMoeConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
model
=
GraniteMoeModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
,
prefix
=
"model"
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
scale
=
1
/
self
.
config
.
logits_scaling
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
new_weights
=
{}
for
n
,
p
in
weights
:
if
n
.
endswith
(
'.block_sparse_moe.input_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w1_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
".block_sparse_moe.experts.%d.w1.weight"
%
e
)
w3_name
=
n
.
replace
(
'.block_sparse_moe.input_linear.weight'
,
".block_sparse_moe.experts.%d.w3.weight"
%
e
)
w1_param
,
w3_param
=
p
[
e
].
chunk
(
2
,
dim
=
0
)
assert
w1_name
not
in
new_weights
assert
w3_name
not
in
new_weights
new_weights
[
w1_name
]
=
w1_param
new_weights
[
w3_name
]
=
w3_param
elif
n
.
endswith
(
'.block_sparse_moe.output_linear.weight'
):
for
e
in
range
(
p
.
size
(
0
)):
w2_name
=
n
.
replace
(
'.block_sparse_moe.output_linear.weight'
,
".block_sparse_moe.experts.%d.w2.weight"
%
e
)
w2_param
=
p
[
e
]
assert
w2_name
not
in
new_weights
new_weights
[
w2_name
]
=
w2_param
elif
n
.
endswith
(
'.block_sparse_moe.router.layer.weight'
):
gate_name
=
n
.
replace
(
'.block_sparse_moe.router.layer.weight'
,
".block_sparse_moe.gate.weight"
)
assert
gate_name
not
in
new_weights
new_weights
[
gate_name
]
=
p
elif
n
==
'lm_head.weight'
and
self
.
config
.
tie_word_embeddings
:
pass
else
:
new_weights
[
n
]
=
p
mixtral
.
MixtralForCausalLM
.
load_weights
(
self
,
new_weights
.
items
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment