Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
44f26a94
Unverified
Commit
44f26a94
authored
Aug 16, 2024
by
Michael Goin
Committed by
GitHub
Aug 16, 2024
Browse files
[Model] Align nemotron config with final HF state and fix lm-eval-small (#7611)
parent
37fd47e7
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
35 deletions
+29
-35
.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml
.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml
+4
-4
.buildkite/lm-eval-harness/configs/models-small.txt
.buildkite/lm-eval-harness/configs/models-small.txt
+1
-1
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+3
-3
vllm/model_executor/models/nemotron.py
vllm/model_executor/models/nemotron.py
+3
-3
vllm/transformers_utils/configs/nemotron.py
vllm/transformers_utils/configs/nemotron.py
+18
-24
No files found.
.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
→
.buildkite/lm-eval-harness/configs/Minitron-4B-Base
-FP8
.yaml
View file @
44f26a94
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m
nvidia
/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m
mgoin
/Minitron-4B-Base
-FP8
-b auto -l 1000 -f 5 -t 1
model_name
:
"
nvidia
/Minitron-4B-Base"
model_name
:
"
mgoin
/Minitron-4B-Base
-FP8
"
tasks
:
tasks
:
-
name
:
"
gsm8k"
-
name
:
"
gsm8k"
metrics
:
metrics
:
-
name
:
"
exact_match,strict-match"
-
name
:
"
exact_match,strict-match"
value
:
0.2
52
value
:
0.2
33
-
name
:
"
exact_match,flexible-extract"
-
name
:
"
exact_match,flexible-extract"
value
:
0.2
52
value
:
0.2
36
limit
:
1000
limit
:
1000
num_fewshot
:
5
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/models-small.txt
View file @
44f26a94
...
@@ -4,7 +4,7 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
...
@@ -4,7 +4,7 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base.yaml
Minitron-4B-Base
-FP8
.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
Meta-Llama-3-8B-QQQ.yaml
Meta-Llama-3-8B-QQQ.yaml
vllm/model_executor/layers/rotary_embedding.py
View file @
44f26a94
...
@@ -774,7 +774,7 @@ def get_rope(
...
@@ -774,7 +774,7 @@ def get_rope(
is_neox_style
:
bool
=
True
,
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
rotary_percent
:
float
=
1.0
,
partial_rotary_factor
:
float
=
1.0
,
)
->
RotaryEmbedding
:
)
->
RotaryEmbedding
:
if
dtype
is
None
:
if
dtype
is
None
:
dtype
=
torch
.
get_default_dtype
()
dtype
=
torch
.
get_default_dtype
()
...
@@ -787,8 +787,8 @@ def get_rope(
...
@@ -787,8 +787,8 @@ def get_rope(
rope_scaling_args
=
tuple
(
rope_scaling_tuple
.
items
())
rope_scaling_args
=
tuple
(
rope_scaling_tuple
.
items
())
else
:
else
:
rope_scaling_args
=
None
rope_scaling_args
=
None
if
rotary_percent
<
1.0
:
if
partial_rotary_factor
<
1.0
:
rotary_dim
=
int
(
rotary_dim
*
rotary_percent
)
rotary_dim
=
int
(
rotary_dim
*
partial_rotary_factor
)
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
rope_scaling_args
,
dtype
)
rope_scaling_args
,
dtype
)
if
key
in
_ROPE_DICT
:
if
key
in
_ROPE_DICT
:
...
...
vllm/model_executor/models/nemotron.py
View file @
44f26a94
...
@@ -53,7 +53,7 @@ from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
...
@@ -53,7 +53,7 @@ from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
# - There is no gate_proj, just up_proj
# - There is no gate_proj, just up_proj
# - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
# - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
# - Squared ReLU instead of SwiGLU
# - Squared ReLU instead of SwiGLU
# - Adds a
rotary_percent
to RoPE
# - Adds a
partial_rotary_factor
to RoPE
def
_cast_if_autocast_enabled
(
*
args
):
def
_cast_if_autocast_enabled
(
*
args
):
...
@@ -161,7 +161,7 @@ class NemotronAttention(nn.Module):
...
@@ -161,7 +161,7 @@ class NemotronAttention(nn.Module):
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
rope_theta
=
rope_theta
self
.
rotary_percent
=
config
.
rope_percent
self
.
partial_rotary_factor
=
config
.
partial_rotary_factor
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
qkv_proj
=
QKVParallelLinear
(
...
@@ -187,7 +187,7 @@ class NemotronAttention(nn.Module):
...
@@ -187,7 +187,7 @@ class NemotronAttention(nn.Module):
max_position
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
rotary_percent
=
self
.
rotary_percent
,
partial_rotary_factor
=
self
.
partial_rotary_factor
,
)
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
...
...
vllm/transformers_utils/configs/nemotron.py
View file @
44f26a94
...
@@ -35,20 +35,20 @@ class NemotronConfig(PretrainedConfig):
...
@@ -35,20 +35,20 @@ class NemotronConfig(PretrainedConfig):
Args:
Args:
vocab_size (`int`, *optional*, defaults to
3
2000):
vocab_size (`int`, *optional*, defaults to 2
56
000):
Vocabulary size of the Nemotron model. Defines the number of
Vocabulary size of the Nemotron model. Defines the number of
different tokens that can be represented by the
different tokens that can be represented by the
`inputs_ids` passed when calling [`NemotronModel`]
`inputs_ids` passed when calling [`NemotronModel`]
hidden_size (`int`, *optional*, defaults to
4096
):
hidden_size (`int`, *optional*, defaults to
6144
):
Dimension of the hidden representations.
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to
11008
):
intermediate_size (`int`, *optional*, defaults to
24576
):
Dimension of the MLP representations.
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to
32
):
num_attention_heads (`int`, *optional*, defaults to
48
):
Number of attention heads for each attention layer in the
Number of attention heads for each attention layer in the
Transformer decoder.
Transformer decoder.
head_dim (`int`, *optional*
, defaults to None
):
head_dim (`int`, *optional*):
Projection weights dimension in multi-head attention. Set to
Projection weights dimension in multi-head attention. Set to
hidden_size // num_attention_heads if None
hidden_size // num_attention_heads if None
num_key_value_heads (`int`, *optional*):
num_key_value_heads (`int`, *optional*):
...
@@ -63,16 +63,16 @@ class NemotronConfig(PretrainedConfig):
...
@@ -63,16 +63,16 @@ class NemotronConfig(PretrainedConfig):
heads within that group. For more details checkout
heads within that group. For more details checkout
[this paper](https://arxiv.org/pdf/2305.13245.pdf). If it
[this paper](https://arxiv.org/pdf/2305.13245.pdf). If it
is not specified, will default to `num_attention_heads`.
is not specified, will default to `num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"
si
lu"`):
hidden_act (`str` or `function`, *optional*, defaults to `"
re
lu
2
"`):
The non-linear activation function (function or string) in the
The non-linear activation function (function or string) in the
decoder.
decoder.
max_position_embeddings (`int`, *optional*, defaults to
2048
):
max_position_embeddings (`int`, *optional*, defaults to
4096
):
The maximum sequence length that this model might ever be used
The maximum sequence length that this model might ever be used
with.
with.
initializer_range (`float`, *optional*, defaults to 0.0
2
):
initializer_range (`float`, *optional*, defaults to 0.0
134
):
The standard deviation of the truncated_normal_initializer for
The standard deviation of the truncated_normal_initializer for
initializing all weight matrices.
initializing all weight matrices.
norm_eps (`float`, *optional*, defaults to 1e-0
6
):
norm_eps (`float`, *optional*, defaults to 1e-0
5
):
The epsilon used by the normalization layers.
The epsilon used by the normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values
Whether or not the model should return the last key/values
...
@@ -80,21 +80,16 @@ class NemotronConfig(PretrainedConfig):
...
@@ -80,21 +80,16 @@ class NemotronConfig(PretrainedConfig):
`config.is_decoder=True`.
`config.is_decoder=True`.
pad_token_id (`int`, *optional*):
pad_token_id (`int`, *optional*):
Padding token id.
Padding token id.
bos_token_id (`int`, *optional*, defaults to
1
):
bos_token_id (`int`, *optional*, defaults to
2
):
Beginning of stream token id.
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to
2
):
eos_token_id (`int`, *optional*, defaults to
3
):
End of stream token id.
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
Dictionary containing the scaling configuration for the RoPE
Percentage of the query and keys which will have rotary embedding.
embeddings. Currently supports two scaling strategies: linear
and dynamic. Their scaling factor must be a float greater than 1.
The expected format is `{"type": strategy name,
"factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
attention_bias (`bool`, *optional*, defaults to `False`):
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output
Whether to use a bias in the query, key, value and output
projection layers during self-attention.
projection layers during self-attention.
...
@@ -106,13 +101,10 @@ class NemotronConfig(PretrainedConfig):
...
@@ -106,13 +101,10 @@ class NemotronConfig(PretrainedConfig):
```python
```python
>>> from transformers import NemotronModel, NemotronConfig
>>> from transformers import NemotronModel, NemotronConfig
>>> # Initializing a Nemotron nemotron-15b style configuration
>>> # Initializing a Nemotron nemotron-15b style configuration
>>> configuration = NemotronConfig()
>>> configuration = NemotronConfig()
>>> # Initializing a model from the nemotron-15b style configuration
>>> # Initializing a model from the nemotron-15b style configuration
>>> model = NemotronModel(configuration)
>>> model = NemotronModel(configuration)
>>> # Accessing the model configuration
>>> # Accessing the model configuration
>>> configuration = model.config
>>> configuration = model.config
```"""
```"""
...
@@ -140,7 +132,7 @@ class NemotronConfig(PretrainedConfig):
...
@@ -140,7 +132,7 @@ class NemotronConfig(PretrainedConfig):
tie_word_embeddings
=
False
,
tie_word_embeddings
=
False
,
rope_theta
=
10000.0
,
rope_theta
=
10000.0
,
rope_scaling
=
None
,
rope_scaling
=
None
,
rope_percent
=
0.5
,
partial_rotary_factor
=
0.5
,
attention_bias
=
False
,
attention_bias
=
False
,
attention_dropout
=
0.0
,
attention_dropout
=
0.0
,
mlp_bias
=
False
,
mlp_bias
=
False
,
...
@@ -167,8 +159,10 @@ class NemotronConfig(PretrainedConfig):
...
@@ -167,8 +159,10 @@ class NemotronConfig(PretrainedConfig):
self
.
use_cache
=
use_cache
self
.
use_cache
=
use_cache
self
.
rope_theta
=
rope_theta
self
.
rope_theta
=
rope_theta
self
.
rope_scaling
=
rope_scaling
self
.
rope_scaling
=
rope_scaling
rope_percent
=
rope_percent
or
kwargs
.
get
(
"rope_percentage"
,
None
)
# for backward compatibility
self
.
rope_percent
=
rope_percent
partial_rotary_factor
=
kwargs
.
get
(
"rope_percent"
,
None
)
or
kwargs
.
get
(
"rope_percentage"
,
None
)
or
partial_rotary_factor
self
.
partial_rotary_factor
=
partial_rotary_factor
self
.
_rope_scaling_validation
()
self
.
_rope_scaling_validation
()
self
.
attention_bias
=
attention_bias
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
self
.
attention_dropout
=
attention_dropout
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment