Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d48d7497
Commit
d48d7497
authored
Jul 01, 2025
by
zhuwenwen
Browse files
[Model] Add Ernie4.5 and Ernie4.5MoE
parent
5d3fb1d4
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1068 additions
and
0 deletions
+1068
-0
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+10
-0
tests/models/registry.py
tests/models/registry.py
+4
-0
vllm/model_executor/models/ernie45.py
vllm/model_executor/models/ernie45.py
+465
-0
vllm/model_executor/models/ernie45_moe.py
vllm/model_executor/models/ernie45_moe.py
+587
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-0
No files found.
docs/source/models/supported_models.md
View file @
d48d7497
...
@@ -315,6 +315,16 @@ See [this page](#generative-models) for more information on how to use generativ
...
@@ -315,6 +315,16 @@ See [this page](#generative-models) for more information on how to use generativ
*
`deepseek-ai/DeepSeek-V3-Base`
,
`deepseek-ai/DeepSeek-V3`
etc.
*
`deepseek-ai/DeepSeek-V3-Base`
,
`deepseek-ai/DeepSeek-V3`
etc.
*
*
*
✅︎
*
✅︎
-
*
`Ernie4_5_ForCausalLM`
*
Ernie4.5
*
`baidu/ERNIE-4.5-0.3B-PT`
, etc.
*
*
✅︎
-
*
`Ernie4_5_MoeForCausalLM`
*
Ernie4.5MoE
*
`baidu/ERNIE-4.5-21B-A3B-PT`
,
`baidu/ERNIE-4.5-300B-A47B-PT`
, etc.
*
*
✅︎
-
*
`ExaoneForCausalLM`
-
*
`ExaoneForCausalLM`
*
EXAONE-3
*
EXAONE-3
*
`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`
, etc.
*
`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`
, etc.
...
...
tests/models/registry.py
View file @
d48d7497
...
@@ -259,6 +259,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -259,6 +259,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
tokenizer
=
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/Llama-2-7b"
),
tokenizer
=
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/Llama-2-7b"
),
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Zamba2ForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Zyphra/Zamba2-7B-instruct"
)),
"Zamba2ForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Zyphra/Zamba2-7B-instruct"
)),
"Ernie4_5_ForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"baidu/ERNIE-4.5-0.3B-PT"
),
trust_remote_code
=
True
),
"Ernie4_5_MoeForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"baidu/ERNIE-4.5-21B-A3B-PT"
),
trust_remote_code
=
True
),
"MiMoForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"XiaomiMiMo/MiMo-7B-RL"
),
"MiMoForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"XiaomiMiMo/MiMo-7B-RL"
),
trust_remote_code
=
True
),
trust_remote_code
=
True
),
# [Encoder-decoder]
# [Encoder-decoder]
...
...
vllm/model_executor/models/ernie45.py
0 → 100644
View file @
d48d7497
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The Baidu team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Erine model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
from
typing
import
Any
,
Optional
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
F
from
.interfaces
import
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
class
Ernie4_5_MLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
use_bias
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
use_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
use_bias
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
Ernie4_5_Attention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
head_dim
:
Optional
[
int
]
=
None
,
rope_theta
:
float
=
500000
,
rope_scaling
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
131072
,
rms_norm_eps
:
float
=
1e-05
,
qkv_bias
:
bool
=
False
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
layer_idx
=
extract_layer_index
(
prefix
)
if
len
(
prefix
)
>
0
else
0
self
.
layer_idx
=
layer_idx
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
head_dim
or
(
hidden_size
//
self
.
total_num_heads
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
qkv_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
is_neox_style
=
False
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
# Attention
attn_output
=
self
.
attn
(
q
,
k
,
v
)
# Output projection
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
Ernie4_5_DecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
500000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
131072
)
self
.
self_attn
=
Ernie4_5_Attention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
head_dim
=
getattr
(
config
,
'head_dim'
,
None
),
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
rms_norm_eps
=
config
.
rms_norm_eps
,
qkv_bias
=
getattr
(
config
,
'use_bias'
,
False
),
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
mlp
=
Ernie4_5_MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
use_bias
=
getattr
(
config
,
'use_bias'
,
False
),
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
@
support_torch_compile
class
Ernie4_5_Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
config
=
config
if
get_pp_group
().
is_first_rank
:
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
Ernie4_5_DecoderLayer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
Ernie4_5_ForCausalLM
(
nn
.
Module
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
Ernie4_5_Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
else
:
self
.
lm_head
=
PPMissingLayer
()
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
\ No newline at end of file
vllm/model_executor/models/ernie45_moe.py
0 → 100644
View file @
d48d7497
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
d48d7497
...
@@ -121,6 +121,8 @@ _TEXT_GENERATION_MODELS = {
...
@@ -121,6 +121,8 @@ _TEXT_GENERATION_MODELS = {
"TeleFLMForCausalLM"
:
(
"teleflm"
,
"TeleFLMForCausalLM"
),
"TeleFLMForCausalLM"
:
(
"teleflm"
,
"TeleFLMForCausalLM"
),
"XverseForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"XverseForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"Zamba2ForCausalLM"
:
(
"zamba2"
,
"Zamba2ForCausalLM"
),
"Zamba2ForCausalLM"
:
(
"zamba2"
,
"Zamba2ForCausalLM"
),
"Ernie4_5_ForCausalLM"
:
(
"ernie45"
,
"Ernie4_5_ForCausalLM"
),
"Ernie4_5_MoeForCausalLM"
:
(
"ernie45_moe"
,
"Ernie4_5_MoeForCausalLM"
),
# [Encoder-decoder]
# [Encoder-decoder]
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment