Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
07278c37
Unverified
Commit
07278c37
authored
Jul 26, 2024
by
Michael Goin
Committed by
GitHub
Jul 26, 2024
Browse files
[Model] Support Nemotron models (Nemotron-3, Nemotron-4, Minitron) (#6611)
parent
85ad7e2d
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
776 additions
and
1 deletion
+776
-1
.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-small.txt
.buildkite/lm-eval-harness/configs/models-small.txt
+1
-0
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+16
-0
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+3
-0
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/nemotron.py
vllm/model_executor/models/nemotron.py
+531
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+2
-1
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/nemotron.py
vllm/transformers_utils/configs/nemotron.py
+209
-0
No files found.
.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
0 → 100644
View file @
07278c37
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
model_name
:
"
nvidia/Minitron-4B-Base"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.252
-
name
:
"
exact_match,flexible-extract"
value
:
0.252
limit
:
1000
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/models-small.txt
View file @
07278c37
...
...
@@ -4,5 +4,6 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
vllm/model_executor/layers/activation.py
View file @
07278c37
...
...
@@ -159,6 +159,21 @@ class QuickGELU(CustomOp):
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
class
ReLUSquaredActivation
(
CustomOp
):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
relu_applied
=
nn
.
functional
.
relu
(
x
)
squared
=
torch
.
square
(
relu_applied
)
return
squared
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
forward_native
(
x
)
class
ScaledActivation
(
nn
.
Module
):
"""An activation function with post-scale parameters.
...
...
@@ -207,6 +222,7 @@ _ACTIVATION_REGISTRY = {
"gelu_new"
:
NewGELU
(),
"gelu_pytorch_tanh"
:
nn
.
GELU
(
approximate
=
"tanh"
),
"relu"
:
nn
.
ReLU
(),
"relu2"
:
ReLUSquaredActivation
(),
"quick_gelu"
:
QuickGELU
(),
}
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
07278c37
...
...
@@ -774,6 +774,7 @@ def get_rope(
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
rotary_percent
:
float
=
1.0
,
)
->
RotaryEmbedding
:
if
dtype
is
None
:
dtype
=
torch
.
get_default_dtype
()
...
...
@@ -786,6 +787,8 @@ def get_rope(
rope_scaling_args
=
tuple
(
rope_scaling_tuple
.
items
())
else
:
rope_scaling_args
=
None
if
rotary_percent
<
1.0
:
rotary_dim
=
int
(
rotary_dim
*
rotary_percent
)
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
rope_scaling_args
,
dtype
)
if
key
in
_ROPE_DICT
:
...
...
vllm/model_executor/models/__init__.py
View file @
07278c37
...
...
@@ -51,6 +51,7 @@ _GENERATION_MODELS = {
"MPTForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MiniCPMForCausalLM"
:
(
"minicpm"
,
"MiniCPMForCausalLM"
),
"MiniCPMV"
:
(
"minicpmv"
,
"MiniCPMV"
),
"NemotronForCausalLM"
:
(
"nemotron"
,
"NemotronForCausalLM"
),
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
...
...
vllm/model_executor/models/nemotron.py
0 → 100644
View file @
07278c37
This diff is collapsed.
Click to expand it.
vllm/transformers_utils/config.py
View file @
07278c37
...
...
@@ -8,7 +8,7 @@ from vllm.logger import init_logger
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
JAISConfig
,
MedusaConfig
,
MLPSpeculatorConfig
,
MPTConfig
,
RWConfig
)
NemotronConfig
,
RWConfig
)
if
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
...
...
@@ -26,6 +26,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"jais"
:
JAISConfig
,
"mlp_speculator"
:
MLPSpeculatorConfig
,
"medusa"
:
MedusaConfig
,
"nemotron"
:
NemotronConfig
,
}
for
name
,
cls
in
_CONFIG_REGISTRY
.
items
():
...
...
vllm/transformers_utils/configs/__init__.py
View file @
07278c37
...
...
@@ -8,6 +8,7 @@ from vllm.transformers_utils.configs.jais import JAISConfig
from
vllm.transformers_utils.configs.medusa
import
MedusaConfig
from
vllm.transformers_utils.configs.mlp_speculator
import
MLPSpeculatorConfig
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
from
vllm.transformers_utils.configs.nemotron
import
NemotronConfig
__all__
=
[
"ChatGLMConfig"
,
...
...
@@ -17,4 +18,5 @@ __all__ = [
"JAISConfig"
,
"MedusaConfig"
,
"MLPSpeculatorConfig"
,
"NemotronConfig"
,
]
vllm/transformers_utils/configs/nemotron.py
0 → 100644
View file @
07278c37
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Nemotron model configuration"""
from
transformers
import
PretrainedConfig
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
class
NemotronConfig
(
PretrainedConfig
):
r
"""
This is the configuration class to store the configuration of a
[`NemotronModel`]. It is used to instantiate an Nemotron model
according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar
configuration to that of the Nemotron-8B.
Configuration objects inherit from [`PretrainedConfig`] and can be
used to control the model outputs. Read the documentation from
[`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Nemotron model. Defines the number of
different tokens that can be represented by the
`inputs_ids` passed when calling [`NemotronModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the
Transformer decoder.
head_dim (`int`, *optional*, defaults to None):
Projection weights dimension in multi-head attention. Set to
hidden_size // num_attention_heads if None
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to
implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use
Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention
(MQA) otherwise GQA is used. When converting a multi-head
checkpoint to a GQA checkpoint, each group key and value
head should be constructed by meanpooling all the original
heads within that group. For more details checkout
[this paper](https://arxiv.org/pdf/2305.13245.pdf). If it
is not specified, will default to `num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the
decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used
with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for
initializing all weight matrices.
norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values
attentions (not used by all models). Only relevant if
`config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE
embeddings. Currently supports two scaling strategies: linear
and dynamic. Their scaling factor must be a float greater than 1.
The expected format is `{"type": strategy name,
"factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output
projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj and down_proj layers in the MLP
layers.
```python
>>> from transformers import NemotronModel, NemotronConfig
>>> # Initializing a Nemotron nemotron-15b style configuration
>>> configuration = NemotronConfig()
>>> # Initializing a model from the nemotron-15b style configuration
>>> model = NemotronModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type
=
"nemotron"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
self
,
vocab_size
=
256000
,
hidden_size
=
6144
,
intermediate_size
=
24576
,
num_hidden_layers
=
32
,
num_attention_heads
=
48
,
head_dim
=
None
,
num_key_value_heads
=
None
,
hidden_act
=
"relu2"
,
max_position_embeddings
=
4096
,
initializer_range
=
0.0134
,
norm_eps
=
1e-5
,
use_cache
=
True
,
pad_token_id
=
None
,
bos_token_id
=
2
,
eos_token_id
=
3
,
tie_word_embeddings
=
False
,
rope_theta
=
10000.0
,
rope_scaling
=
None
,
rope_percent
=
0.5
,
attention_bias
=
False
,
attention_dropout
=
0.0
,
mlp_bias
=
False
,
**
kwargs
,
):
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
head_dim
=
head_dim
or
kwargs
.
get
(
"kv_channels"
,
None
)
self
.
head_dim
=
head_dim
if
head_dim
is
not
None
else
(
hidden_size
//
num_attention_heads
)
# for backward compatibility
if
num_key_value_heads
is
None
:
num_key_value_heads
=
num_attention_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
norm_eps
=
norm_eps
self
.
use_cache
=
use_cache
self
.
rope_theta
=
rope_theta
self
.
rope_scaling
=
rope_scaling
rope_percent
=
rope_percent
or
kwargs
.
get
(
"rope_percentage"
,
None
)
self
.
rope_percent
=
rope_percent
self
.
_rope_scaling_validation
()
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
self
.
mlp_bias
=
mlp_bias
super
().
__init__
(
pad_token_id
=
pad_token_id
,
bos_token_id
=
bos_token_id
,
eos_token_id
=
eos_token_id
,
tie_word_embeddings
=
tie_word_embeddings
,
**
kwargs
,
)
def
_rope_scaling_validation
(
self
):
"""
Validate the `rope_scaling` configuration.
"""
if
self
.
rope_scaling
is
None
:
return
if
not
isinstance
(
self
.
rope_scaling
,
dict
)
or
len
(
self
.
rope_scaling
)
!=
2
:
raise
ValueError
(
"`rope_scaling` must be a dictionary with two fields, "
f
"`type` and `factor`, got
{
self
.
rope_scaling
}
"
)
rope_scaling_type
=
self
.
rope_scaling
.
get
(
"type"
,
None
)
rope_scaling_factor
=
self
.
rope_scaling
.
get
(
"factor"
,
None
)
if
rope_scaling_type
is
None
or
rope_scaling_type
not
in
[
"linear"
,
"dynamic"
]:
raise
ValueError
(
"`rope_scaling`'s type field must be one of ['linear', "
f
"'dynamic'], got
{
rope_scaling_type
}
"
)
if
rope_scaling_factor
is
None
or
not
isinstance
(
rope_scaling_factor
,
float
)
or
rope_scaling_factor
<=
1.0
:
raise
ValueError
(
"`rope_scaling`'s factor field must be a float > 1, got "
f
"
{
rope_scaling_factor
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment