Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7462218e
Commit
7462218e
authored
Sep 05, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.5.0-dtk24.04.1'
parents
6ccd3f47
1cec5e62
Changes
60
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1454 additions
and
149 deletions
+1454
-149
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+13
-3
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+2
-1
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+159
-61
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+27
-10
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+24
-52
vllm/model_executor/layers/ops/rand.py
vllm/model_executor/layers/ops/rand.py
+9
-2
vllm/model_executor/layers/ops/sample.py
vllm/model_executor/layers/ops/sample.py
+9
-2
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+63
-13
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+139
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+1
-1
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+16
-1
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+51
-0
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+57
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+534
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+94
-2
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+99
-0
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+99
-0
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+29
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+28
-0
No files found.
vllm/model_executor/layers/activation.py
View file @
7462218e
...
@@ -11,6 +11,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -11,6 +11,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
import
vllm.envs
as
envs
class
SiluAndMul
(
CustomOp
):
class
SiluAndMul
(
CustomOp
):
...
@@ -34,7 +35,10 @@ class SiluAndMul(CustomOp):
...
@@ -34,7 +35,10 @@ class SiluAndMul(CustomOp):
d
=
x
.
shape
[
-
1
]
//
2
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
ops
.
silu_and_mul
(
out
,
x
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
silu_and_mul_opt
(
out
,
x
)
else
:
ops
.
silu_and_mul
(
out
,
x
)
return
out
return
out
...
@@ -66,9 +70,15 @@ class GeluAndMul(CustomOp):
...
@@ -66,9 +70,15 @@ class GeluAndMul(CustomOp):
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
if
self
.
approximate
==
"none"
:
if
self
.
approximate
==
"none"
:
ops
.
gelu_and_mul
(
out
,
x
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
gelu_and_mul_opt
(
out
,
x
)
else
:
ops
.
gelu_and_mul
(
out
,
x
)
elif
self
.
approximate
==
"tanh"
:
elif
self
.
approximate
==
"tanh"
:
ops
.
gelu_tanh_and_mul
(
out
,
x
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
gelu_tanh_and_mul_opt
(
out
,
x
)
else
:
ops
.
gelu_tanh_and_mul
(
out
,
x
)
return
out
return
out
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
7462218e
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
)
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
,
grouped_topk
)
__all__
=
[
__all__
=
[
"fused_moe"
,
"fused_moe"
,
"fused_topk"
,
"fused_topk"
,
"fused_experts"
,
"fused_experts"
,
"get_config_file_name"
,
"get_config_file_name"
,
"grouped_topk"
,
]
]
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
7462218e
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -331,6 +332,31 @@ def get_default_config(
...
@@ -331,6 +332,31 @@ def get_default_config(
return
config
return
config
def
try_get_optimal_moe_config
(
w1_shape
:
Tuple
[
int
,
...],
w2_shape
:
Tuple
[
int
,
...],
top_k
:
int
,
dtype
:
Optional
[
str
],
M
:
int
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
):
if
override_config
:
config
=
override_config
else
:
# First try to load optimal config from the file
E
,
_
,
N
=
w2_shape
configs
=
get_moe_configs
(
E
,
N
,
dtype
)
if
configs
:
# If an optimal configuration map has been found, look up the
# optimal config
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
],
top_k
,
dtype
)
return
config
def
fused_topk
(
def
fused_topk
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
...
@@ -367,6 +393,39 @@ def fused_topk(
...
@@ -367,6 +393,39 @@ def fused_topk(
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
# This is used by the Deepseek-V2 model
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
):
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
num_token
=
scores
.
shape
[
0
]
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
max
(
dim
=-
1
).
values
# [n, n_group]
group_idx
=
torch
.
topk
(
group_scores
,
k
=
topk_group
,
dim
=-
1
,
sorted
=
False
)[
1
]
# [n, top_k_group]
group_mask
=
torch
.
zeros_like
(
group_scores
)
# [n, n_group]
group_mask
.
scatter_
(
1
,
group_idx
,
1
)
# [n, n_group]
score_mask
=
group_mask
.
unsqueeze
(
-
1
).
expand
(
num_token
,
num_expert_group
,
scores
.
shape
[
-
1
]
//
num_expert_group
).
reshape
(
num_token
,
-
1
)
# [n, e]
tmp_scores
=
scores
.
masked_fill
(
~
score_mask
.
bool
(),
0.0
)
# [n, e]
topk_weights
,
topk_ids
=
torch
.
topk
(
tmp_scores
,
k
=
topk
,
dim
=-
1
,
sorted
=
False
)
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
,
topk_ids
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
@@ -389,25 +448,23 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -389,25 +448,23 @@ def fused_experts(hidden_states: torch.Tensor,
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
]
M
,
_
=
hidden_states
.
shape
num_tokens
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
E
,
N
,
_
=
w1
.
shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
None
,
override_config
=
override_config
,
)
if
override_config
:
config
=
get_config_func
(
M
)
config
=
override_config
else
:
# First try to load optimal config from the file
configs
=
get_moe_configs
(
E
,
w2
.
shape
[
2
],
"float8"
if
use_fp8
else
None
)
if
configs
:
# If an optimal configuration map has been found, look up the
# optimal config
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1
.
shape
[
2
],
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
None
)
intermediate_cache1
=
torch
.
empty
((
M
,
topk_ids
.
shape
[
1
],
N
),
intermediate_cache1
=
torch
.
empty
((
M
,
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
...
@@ -419,51 +476,78 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -419,51 +476,78 @@ def fused_experts(hidden_states: torch.Tensor,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
)
compute_type
=
(
tl
.
bfloat16
compute_type
=
(
tl
.
bfloat16
if
hidden_states
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
)
if
hidden_states
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
)
invoke_fused_moe_kernel
(
hidden_states
,
w1
,
intermediate_cache1
,
a1_scale
,
w1_scale
,
topk_weights
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
topk_ids
.
shape
[
1
],
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
intermediate_cache3
,
a2_scale
,
w2_scale
,
topk_weights
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
True
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
if
inplace
:
if
inplace
:
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
=
hidden_states
dim
=
1
,
else
:
out
=
hidden_states
)
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
)
for
chunk
in
range
((
num_tokens
//
CHUNK_SIZE
)
+
1
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
if
tokens_in_chunk
==
0
:
break
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1
=
intermediate_cache1
[:
tokens_in_chunk
]
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
]
config
=
get_config_func
(
tokens_in_chunk
)
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
invoke_fused_moe_kernel
(
curr_hidden_states
,
w1
,
intermediate_cache1
,
a1_scale
,
w1_scale
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
topk_ids
.
shape
[
1
],
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
intermediate_cache3
,
a2_scale
,
w2_scale
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
True
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
,
out
=
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
out_hidden_states
def
fused_moe
(
def
fused_moe
(
...
@@ -475,6 +559,9 @@ def fused_moe(
...
@@ -475,6 +559,9 @@ def fused_moe(
renormalize
:
bool
,
renormalize
:
bool
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
use_fp8
:
bool
=
False
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -497,6 +584,10 @@ def fused_moe(
...
@@ -497,6 +584,10 @@ def fused_moe(
Defaults to False.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
...
@@ -510,8 +601,15 @@ def fused_moe(
...
@@ -510,8 +601,15 @@ def fused_moe(
# Check constraints.
# Check constraints.
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
if
use_grouped_topk
:
renormalize
)
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
,
num_expert_group
,
topk_group
)
else
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
return
fused_experts
(
hidden_states
,
return
fused_experts
(
hidden_states
,
w1
,
w1
,
w2
,
w2
,
...
@@ -523,4 +621,4 @@ def fused_moe(
...
@@ -523,4 +621,4 @@ def fused_moe(
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
)
a2_scale
=
a2_scale
)
\ No newline at end of file
vllm/model_executor/layers/layernorm.py
View file @
7462218e
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
import
vllm.envs
as
envs
class
RMSNorm
(
CustomOp
):
class
RMSNorm
(
CustomOp
):
...
@@ -51,20 +52,36 @@ class RMSNorm(CustomOp):
...
@@ -51,20 +52,36 @@ class RMSNorm(CustomOp):
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
if
residual
is
not
None
:
if
residual
is
not
None
:
ops
.
fused_add_rms_norm
(
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
fused_add_rms_norm_opt
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
)
else
:
ops
.
fused_add_rms_norm
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
)
return
x
,
residual
out
=
torch
.
empty_like
(
x
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
rms_norm_opt
(
out
,
x
,
x
,
residual
,
self
.
weight
.
data
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
self
.
variance_epsilon
,
)
)
return
x
,
residual
else
:
out
=
torch
.
empty_like
(
x
)
ops
.
rms_norm
(
ops
.
rms_norm
(
out
,
out
,
x
,
x
,
self
.
weight
.
data
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
self
.
variance_epsilon
,
)
)
return
out
return
out
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
...
vllm/model_executor/layers/linear.py
View file @
7462218e
...
@@ -15,8 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -15,8 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.logger
import
init_logger
import
os
import
os
from
vllm.model_executor.utils
import
gemm_bank_conf
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -89,7 +89,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -89,7 +89,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
def
__init__
(
self
,
separate_bias_add
:
bool
=
False
):
def
__init__
(
self
,
separate_bias_add
:
bool
=
False
):
self
.
separate_bias_add
=
separate_bias_add
self
.
separate_bias_add
=
separate_bias_add
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
...
@@ -108,17 +109,23 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -108,17 +109,23 @@ class UnquantizedLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
weight
=
layer
.
weight
weight
=
layer
.
weight
if
self
.
separate_bias_add
:
if
self
.
separate_bias_add
:
if
bias
is
not
None
:
if
bias
is
not
None
:
return
F
.
linear
(
x
,
weight
)
+
bias
return
F
.
linear
(
x
,
weight
)
+
bias
return
F
.
linear
(
x
,
weight
)
return
F
.
linear
(
x
,
weight
)
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
weight
=
weight
.
reshape
(
weight
.
shape
[
1
],
-
1
)
if
gemm_bank_conf
(
weight
.
shape
[
1
]
-
32
)
and
os
.
environ
[
'GEMM_PAD'
]
==
'1'
:
weight
=
weight
[:,:
-
32
]
if
bias
is
not
None
:
if
bias
is
not
None
:
return
torch
.
matmul
(
x
,
weight
)
+
bias
if
len
(
x
.
shape
)
==
2
:
return
torch
.
addmm
(
bias
,
x
,
weight
)
else
:
return
torch
.
matmul
(
x
,
weight
)
+
bias
else
:
else
:
return
torch
.
matmul
(
x
,
weight
)
return
torch
.
matmul
(
x
,
weight
)
else
:
else
:
return
F
.
linear
(
x
,
weight
,
bias
)
return
F
.
linear
(
x
,
weight
,
bias
)
...
@@ -279,7 +286,6 @@ class ColumnParallelLinear(LinearBase):
...
@@ -279,7 +286,6 @@ class ColumnParallelLinear(LinearBase):
})
})
else
:
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
register_parameter
(
"bias"
,
None
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for Fp8 scales.
# Special case for Fp8 scales.
...
@@ -301,9 +307,6 @@ class ColumnParallelLinear(LinearBase):
...
@@ -301,9 +307,6 @@ class ColumnParallelLinear(LinearBase):
shard_id
=
0
)
shard_id
=
0
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
if
self
.
use_llama_nn
:
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
param_data
.
shape
[
0
],
-
1
)
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
...
@@ -368,8 +371,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -368,8 +371,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add
=
skip_bias_add
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -448,21 +449,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -448,21 +449,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin.
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
param
,
shard_size
,
shard_offset
)
use_bitsandbytes
=
getattr
(
param
,
"use_bitsandbytes"
,
False
)
use_bitsandbytes
=
getattr
(
param
,
"use_bitsandbytes"
,
False
)
if
use_bitsandbytes
:
if
use_bitsandbytes
:
shard_size
=
loaded_weight
.
shape
[
output_dim
]
shard_size
=
loaded_weight
.
shape
[
output_dim
]
shard_offset
=
loaded_weight
.
shape
[
output_dim
]
*
\
shard_offset
=
loaded_weight
.
shape
[
output_dim
]
*
\
loaded_shard_id
loaded_shard_id
if
self
.
use_llama_nn
:
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
param_data_
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
shard_size
)
else
:
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
shard_size
)
...
@@ -498,17 +493,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -498,17 +493,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if
len
(
loaded_weight
.
shape
)
==
0
:
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
self
.
use_llama_nn
:
assert
param_data_
.
shape
==
loaded_weight
.
shape
param_data_
.
copy_
(
loaded_weight
)
if
loaded_shard_id
==
1
and
len
(
param_data
.
shape
)
==
2
:
param_data
=
param_data
.
transpose
(
0
,
1
)
param
.
data
=
param_data
.
reshape
(
param_data
.
shape
[
1
],
-
1
)
else
:
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
class
QKVParallelLinear
(
ColumnParallelLinear
):
...
@@ -568,6 +555,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -568,6 +555,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self
.
num_kv_heads
*
self
.
head_size
*
tp_size
,
# k_proj
self
.
num_kv_heads
*
self
.
head_size
*
tp_size
,
# k_proj
self
.
num_kv_heads
*
self
.
head_size
*
tp_size
,
# v_proj
self
.
num_kv_heads
*
self
.
head_size
*
tp_size
,
# v_proj
]
]
super
().
__init__
(
input_size
=
input_size
,
super
().
__init__
(
input_size
=
input_size
,
output_size
=
output_size
,
output_size
=
output_size
,
bias
=
bias
,
bias
=
bias
,
...
@@ -575,7 +563,6 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -575,7 +563,6 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add
=
skip_bias_add
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -683,14 +670,9 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -683,14 +670,9 @@ class QKVParallelLinear(ColumnParallelLinear):
}
}
shard_size
,
shard_offset
=
adjust_bitsandbytes_shard
(
shard_size
,
shard_offset
=
adjust_bitsandbytes_shard
(
param
,
orig_qkv_offsets
,
loaded_shard_id
)
param
,
orig_qkv_offsets
,
loaded_shard_id
)
if
self
.
use_llama_nn
:
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
param_data_
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
shard_size
)
else
:
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
if
loaded_shard_id
==
"q"
:
if
loaded_shard_id
==
"q"
:
shard_id
=
tp_rank
shard_id
=
tp_rank
else
:
else
:
...
@@ -722,21 +704,15 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -722,21 +704,15 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
"for all partitions."
)
if
len
(
param_data
.
shape
)
==
0
:
if
len
(
param_data
.
shape
)
==
0
:
param_data
=
param_data
.
reshape
(
1
)
param_data
=
param_data
.
reshape
(
1
)
if
len
(
loaded_weight
.
shape
)
==
0
:
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
self
.
use_llama_nn
:
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data_
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
param_data_
.
copy_
(
loaded_weight
)
if
loaded_shard_id
==
"v"
and
len
(
param_data
.
shape
)
==
2
:
param_data
=
param_data
.
transpose
(
0
,
1
)
param
.
data
=
param_data
.
reshape
(
param_data
.
shape
[
1
],
-
1
)
else
:
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
class
RowParallelLinear
(
LinearBase
):
class
RowParallelLinear
(
LinearBase
):
...
@@ -805,7 +781,6 @@ class RowParallelLinear(LinearBase):
...
@@ -805,7 +781,6 @@ class RowParallelLinear(LinearBase):
})
})
else
:
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
register_parameter
(
"bias"
,
None
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for Fp8 scales.
# Special case for Fp8 scales.
...
@@ -831,9 +806,6 @@ class RowParallelLinear(LinearBase):
...
@@ -831,9 +806,6 @@ class RowParallelLinear(LinearBase):
loaded_weight
=
loaded_weight
.
reshape
(
1
)
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
if
self
.
use_llama_nn
:
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
param_data
.
shape
[
0
],
-
1
)
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
...
...
vllm/model_executor/layers/ops/rand.py
View file @
7462218e
...
@@ -3,6 +3,7 @@ from typing import Optional, Union
...
@@ -3,6 +3,7 @@ from typing import Optional, Union
import
torch
import
torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.utils
import
is_hip
def
seeded_uniform
(
def
seeded_uniform
(
...
@@ -69,9 +70,15 @@ def seeded_uniform(
...
@@ -69,9 +70,15 @@ def seeded_uniform(
# Manual tuning. This seems to give best performance on A100 for
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
# simple kernels like this.
if
philox_block_size
>=
8192
:
if
philox_block_size
>=
8192
:
num_warps
=
32
if
is_hip
():
num_warps
=
16
else
:
num_warps
=
32
elif
philox_block_size
>=
4096
:
elif
philox_block_size
>=
4096
:
num_warps
=
16
if
is_hip
():
num_warps
=
8
else
:
num_warps
=
16
elif
philox_block_size
>=
2048
:
elif
philox_block_size
>=
2048
:
num_warps
=
8
num_warps
=
8
...
...
vllm/model_executor/layers/ops/sample.py
View file @
7462218e
...
@@ -6,6 +6,7 @@ import triton
...
@@ -6,6 +6,7 @@ import triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.model_executor.layers.ops.rand
import
seeded_uniform
from
vllm.model_executor.layers.ops.rand
import
seeded_uniform
from
vllm.utils
import
is_hip
_EPS
=
1e-6
_EPS
=
1e-6
...
@@ -278,9 +279,15 @@ def _sample(probs: torch.Tensor,
...
@@ -278,9 +279,15 @@ def _sample(probs: torch.Tensor,
# Manual tuning. This seems to give best performance on A100 for
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
# simple kernels like this.
if
block_size
>=
8192
:
if
block_size
>=
8192
:
num_warps
=
32
if
is_hip
():
num_warps
=
16
else
:
num_warps
=
32
elif
block_size
>=
4096
:
elif
block_size
>=
4096
:
num_warps
=
16
if
is_hip
():
num_warps
=
8
else
:
num_warps
=
16
elif
block_size
>=
2048
:
elif
block_size
>=
2048
:
num_warps
=
8
num_warps
=
8
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
7462218e
...
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
...
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
import
torch.nn.functional
as
F
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
...
@@ -9,6 +10,19 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -9,6 +10,19 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
class
AWQShareWorkSpace
:
_instance
=
None
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
cls
.
_instance
is
None
:
cls
.
_instance
=
super
(
AWQShareWorkSpace
,
cls
).
__new__
(
cls
,
*
args
,
**
kwargs
)
# 执行初始化
cls
.
_instance
.
_initialize
()
return
cls
.
_instance
def
_initialize
(
self
):
self
.
awqworkshapcesize
=
ops
.
GetAWQShareWorkspaceSize
()
self
.
awqworkshapce
=
ops
.
GetAWQShareWorkspace
()
class
AWQConfig
(
QuantizationConfig
):
class
AWQConfig
(
QuantizationConfig
):
"""Config class for AWQ.
"""Config class for AWQ.
...
@@ -81,6 +95,7 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -81,6 +95,7 @@ class AWQLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
AWQConfig
):
def
__init__
(
self
,
quant_config
:
AWQConfig
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
awqsingleton
=
AWQShareWorkSpace
()
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
...
@@ -142,6 +157,19 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -142,6 +157,19 @@ class AWQLinearMethod(LinearMethodBase):
"input_dim"
:
0
,
"input_dim"
:
0
,
"output_dim"
:
1
,
"output_dim"
:
1
,
})
})
zeros_and_scales
=
Parameter
(
torch
.
empty
(
(
input_size_per_partition
//
self
.
quant_config
.
group_size
),
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
zeros_and_scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
...
@@ -149,27 +177,49 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -149,27 +177,49 @@ class AWQLinearMethod(LinearMethodBase):
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
register_parameter
(
"zeros_and_scales"
,
zeros_and_scales
)
set_weight_attrs
(
zeros_and_scales
,
extra_weight_attrs
)
def
apply
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
qweight
=
layer
.
qweight
scales
=
layer
.
scales
zeros_and_scales
=
layer
.
zeros_and_scales
qzeros
=
layer
.
qzeros
pack_factor
=
self
.
quant_config
.
pack_factor
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
0
]
*
1
,
))
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
]
*
pack_factor
,
))
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
# num_tokens >= threshold
m
=
reshaped_x
.
shape
[
0
]
FP16_MATMUL_HEURISTIC_CONDITION
=
x
.
shape
[
:
-
1
]
.
numel
()
>=
256
k
=
reshaped_
x
.
shape
[
-
1
]
n
=
qweight
.
shape
[
0
]
if
FP16_MATMUL_HEURISTIC_CONDITION
:
out
=
ops
.
awq_dequantize
(
qweight
,
scales
,
qzeros
,
0
,
0
,
0
)
if
k
%
4096
==
0
:
out
=
torch
.
matmul
(
reshaped_x
,
out
)
padding_group
=
2
else
:
else
:
out
=
ops
.
awq_gemm
(
reshaped_x
,
qweight
,
scales
,
qzeros
,
padding_group
=
0
pack_factor
)
if
m
<
4096
:
out
=
ops
.
awq_gemm
(
reshaped_x
,
qweight
,
zeros_and_scales
,
m
,
n
,
k
,
self
.
quant_config
.
group_size
,
padding_group
,
self
.
awqsingleton
.
awqworkshapce
,
self
.
awqsingleton
.
awqworkshapcesize
)
else
:
#下面是采用rocblas的做法
deqweight
=
ops
.
dequant_w4_gemm_colmajor
(
#shape[n,k/8]--->[n,k]
qweight
,
zeros_and_scales
,
k
,
n
,
self
.
quant_config
.
group_size
)
out
=
F
.
linear
(
reshaped_x
,
deqweight
[:,
0
:
k
])
if
bias
is
not
None
:
if
bias
is
not
None
:
out
.
add_
(
bias
)
out
.
add_
(
bias
)
return
out
.
reshape
(
out_shape
)
return
out
.
reshape
(
out_shape
)
vllm/model_executor/layers/rotary_embedding.py
View file @
7462218e
...
@@ -576,6 +576,129 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
...
@@ -576,6 +576,129 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
return
query
.
flatten
(
-
2
),
key
.
flatten
(
-
2
)
return
query
.
flatten
(
-
2
),
key
.
flatten
(
-
2
)
def
yarn_get_mscale
(
scale
:
float
=
1
,
mscale
:
float
=
1
)
->
float
:
if
scale
<=
1
:
return
1.0
return
0.1
*
mscale
*
math
.
log
(
scale
)
+
1.0
class
DeepseekScalingRotaryEmbedding
(
RotaryEmbedding
):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
"""
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
*
,
extrapolation_factor
:
float
=
1
,
attn_factor
:
float
=
1
,
beta_fast
:
int
=
32
,
beta_slow
:
int
=
1
,
mscale
:
float
=
1
,
mscale_all_dim
:
float
=
0
,
)
->
None
:
self
.
scaling_factor
=
scaling_factor
self
.
extrapolation_factor
=
extrapolation_factor
self
.
attn_factor
=
attn_factor
self
.
beta_fast
=
beta_fast
self
.
beta_slow
=
beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self
.
mscale
=
float
(
yarn_get_mscale
(
self
.
scaling_factor
,
float
(
mscale
))
/
yarn_get_mscale
(
self
.
scaling_factor
,
float
(
mscale_all_dim
))
*
attn_factor
)
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
scaling_factor
:
float
)
->
torch
.
Tensor
:
pos_freqs
=
self
.
base
**
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
float
,
device
=
"cuda"
)
/
self
.
rotary_dim
)
inv_freq_extrapolation
=
1.0
/
pos_freqs
inv_freq_interpolation
=
1.0
/
(
scaling_factor
*
pos_freqs
)
low
,
high
=
_yarn_find_correction_range
(
self
.
beta_fast
,
self
.
beta_slow
,
self
.
rotary_dim
,
self
.
base
,
self
.
max_position_embeddings
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask
=
(
1
-
_yarn_linear_ramp_mask
(
low
,
high
,
self
.
rotary_dim
//
2
,
dtype
=
torch
.
float
))
*
self
.
extrapolation_factor
inv_freq
=
inv_freq_interpolation
*
(
1
-
inv_freq_mask
)
+
inv_freq_extrapolation
*
inv_freq_mask
return
inv_freq
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
self
.
scaling_factor
)
t
=
torch
.
arange
(
self
.
max_position_embeddings
*
self
.
scaling_factor
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
(
freqs
.
cos
()
*
self
.
mscale
)
sin
=
(
freqs
.
sin
()
*
self
.
mscale
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
print
(
"Cache shape"
,
cache
.
shape
)
return
cache
def
forward
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward()."""
query_rot
=
query
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
if
self
.
rotary_dim
<
self
.
head_size
:
query_pass
=
query
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
positions
.
device
)
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
self
.
is_neox_style
:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos
=
cos
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
else
:
cos
=
cos
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
rotate_fn
=
_rotate_neox
if
self
.
is_neox_style
else
_rotate_gptj
query_rot
=
query_rot
*
cos
+
rotate_fn
(
query_rot
)
*
sin
key_rot
=
key_rot
*
cos
+
rotate_fn
(
key_rot
)
*
sin
if
self
.
rotary_dim
<
self
.
head_size
:
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
)
else
:
query
=
query_rot
key
=
key_rot
return
query
,
key
class
GemmaRotaryEmbedding
(
RotaryEmbedding
):
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
int64
).
float
()
/
self
.
rotary_dim
))
return
inv_freq
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
@@ -633,7 +756,22 @@ def get_rope(
...
@@ -633,7 +756,22 @@ def get_rope(
base
,
is_neox_style
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
scaling_factor
,
dtype
,
**
extra_kwargs
)
**
extra_kwargs
)
elif
scaling_type
==
"su"
:
elif
scaling_type
==
"deepseek_yarn"
:
original_max_position
=
rope_scaling
[
"original_max_position_embeddings"
]
# assert max_position == original_max_position * scaling_factor
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_scaling
.
items
()
if
k
in
(
"extrapolation_factor"
,
"attn_factor"
,
"beta_fast"
,
"beta_slow"
,
"mscale"
,
"mscale_all_dim"
)
}
rotary_emb
=
DeepseekScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
original_max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
**
extra_kwargs
)
# The correct one should be "longrope" but keep "su" here
# for backward compatible
elif
scaling_type
==
"su"
or
scaling_type
==
"longrope"
:
short_factor
=
rope_scaling
[
"short_factor"
]
short_factor
=
rope_scaling
[
"short_factor"
]
long_factor
=
rope_scaling
[
"long_factor"
]
long_factor
=
rope_scaling
[
"long_factor"
]
original_max_position
=
rope_scaling
[
original_max_position
=
rope_scaling
[
...
...
vllm/model_executor/model_loader/loader.py
View file @
7462218e
...
@@ -274,7 +274,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -274,7 +274,7 @@ class DefaultModelLoader(BaseModelLoader):
for
_
,
module
in
model
.
named_modules
():
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
if
quant_method
is
not
None
and
quant_method
!=
"awq"
and
quant_method
!=
"gptq"
:
quant_method
.
process_weights_after_loading
(
module
)
quant_method
.
process_weights_after_loading
(
module
)
# FIXME: Remove this after Mixtral is updated
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
# to use quant_method.
...
...
vllm/model_executor/model_loader/utils.py
View file @
7462218e
...
@@ -22,9 +22,24 @@ def set_default_torch_dtype(dtype: torch.dtype):
...
@@ -22,9 +22,24 @@ def set_default_torch_dtype(dtype: torch.dtype):
def
get_model_architecture
(
def
get_model_architecture
(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
if
architectures
==
[
'LlamaForCausalLM'
]
or
architectures
==
[
'ChatGLMModel'
]
or
architectures
==
[
'BaichuanForCausalLM'
]:
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
]
use_triton_fa_architectures
=
[
'DeepseekV2ForCausalLM'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
os
.
environ
[
'LLAMA_NN'
]
=
'1'
if
os
.
getenv
(
'GEMM_PAD'
)
!=
'1'
:
os
.
environ
[
'GEMM_PAD'
]
=
'0'
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
os
.
environ
[
'FA_PAD'
]
=
'0'
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
if
any
(
arch
in
architectures
for
arch
in
use_triton_fa_architectures
):
os
.
environ
[
'VLLM_USE_TRITON_FLASH_ATTN'
]
=
'1'
os
.
environ
[
'VLLM_USE_FLASH_ATTN_AUTO'
]
=
'0'
# Special handling for quantized Mixtral.
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
# FIXME(woosuk): This is a temporary hack.
if
(
model_config
.
quantization
is
not
None
if
(
model_config
.
quantization
is
not
None
...
...
vllm/model_executor/models/__init__.py
View file @
7462218e
...
@@ -21,6 +21,7 @@ _GENERATION_MODELS = {
...
@@ -21,6 +21,7 @@ _GENERATION_MODELS = {
"DbrxForCausalLM"
:
(
"dbrx"
,
"DbrxForCausalLM"
),
"DbrxForCausalLM"
:
(
"dbrx"
,
"DbrxForCausalLM"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekV2ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV2ForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"GemmaForCausalLM"
:
(
"gemma"
,
"GemmaForCausalLM"
),
"GemmaForCausalLM"
:
(
"gemma"
,
"GemmaForCausalLM"
),
"GPT2LMHeadModel"
:
(
"gpt2"
,
"GPT2LMHeadModel"
),
"GPT2LMHeadModel"
:
(
"gpt2"
,
"GPT2LMHeadModel"
),
...
...
vllm/model_executor/models/baichuan.py
View file @
7462218e
...
@@ -24,6 +24,8 @@ from typing import Iterable, List, Optional, Tuple
...
@@ -24,6 +24,8 @@ from typing import Iterable, List, Optional, Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
@@ -45,6 +47,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -45,6 +47,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
total_num_heads
))
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
total_num_heads
))
...
@@ -169,6 +174,11 @@ class BaiChuanAttention(nn.Module):
...
@@ -169,6 +174,11 @@ class BaiChuanAttention(nn.Module):
self
.
scaling
,
self
.
scaling
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
...
@@ -178,6 +188,8 @@ class BaiChuanAttention(nn.Module):
...
@@ -178,6 +188,8 @@ class BaiChuanAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
postion_embedding
!=
"ALIBI"
:
if
self
.
postion_embedding
!=
"ALIBI"
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
...
@@ -326,6 +338,15 @@ class BaiChuanBaseForCausalLM(nn.Module):
...
@@ -326,6 +338,15 @@ class BaiChuanBaseForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
def
forward
(
self
,
self
,
...
@@ -393,6 +414,36 @@ class BaiChuanBaseForCausalLM(nn.Module):
...
@@ -393,6 +414,36 @@ class BaiChuanBaseForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attn.W_pack.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attn.W_pack.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
class
BaichuanForCausalLM
(
BaiChuanBaseForCausalLM
):
class
BaichuanForCausalLM
(
BaiChuanBaseForCausalLM
):
...
...
vllm/model_executor/models/chatglm.py
View file @
7462218e
...
@@ -7,6 +7,8 @@ from typing import Iterable, List, Optional, Tuple
...
@@ -7,6 +7,8 @@ from typing import Iterable, List, Optional, Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
torch.nn
import
LayerNorm
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
@@ -28,6 +30,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -28,6 +30,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
class
GLMAttention
(
nn
.
Module
):
class
GLMAttention
(
nn
.
Module
):
...
@@ -92,6 +97,11 @@ class GLMAttention(nn.Module):
...
@@ -92,6 +97,11 @@ class GLMAttention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
...
@@ -101,6 +111,8 @@ class GLMAttention(nn.Module):
...
@@ -101,6 +111,8 @@ class GLMAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
context_layer
=
self
.
attn
(
context_layer
=
self
.
attn
(
...
@@ -353,6 +365,15 @@ class ChatGLMForCausalLM(nn.Module):
...
@@ -353,6 +365,15 @@ class ChatGLMForCausalLM(nn.Module):
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
def
forward
(
self
,
self
,
...
@@ -393,3 +414,39 @@ class ChatGLMForCausalLM(nn.Module):
...
@@ -393,3 +414,39 @@ class ChatGLMForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attention.query_key_value.weight"
,
"self_attention.dense.weight"
,
"mlp.dense_h_to_4h.weight"
,
"mlp.dense_4h_to_h.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attention.query_key_value.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
lay_qkv_bias_words
=
[
"self_attention.query_key_value.bias"
]
qkv_bias_words
=
"|"
.
join
(
lay_qkv_bias_words
)
for
layername
,
weight
in
params_dict
.
items
():
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_bias_words
,
layername
)):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
vllm/model_executor/models/deepseek_v2.py
0 → 100644
View file @
7462218e
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2 model."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_experts
,
grouped_topk
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
class
DeepseekV2MLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
DeepseekV2MoE
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
n_routed_experts
=
config
.
n_routed_experts
self
.
top_k
=
config
.
num_experts_per_tok
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
if
self
.
tp_size
>
self
.
n_routed_experts
:
raise
ValueError
(
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"the number of experts
{
self
.
n_routed_experts
}
."
)
self
.
experts
=
nn
.
ModuleList
([
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
False
)
for
idx
in
range
(
self
.
n_routed_experts
)
])
self
.
pack_params
()
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
n_routed_experts
,
bias
=
False
,
quant_config
=
None
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
self
.
shared_experts
=
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
False
,
)
def
pack_params
(
self
):
w1
=
[]
w2
=
[]
for
expert
in
self
.
experts
:
w1
.
append
(
expert
.
gate_up_proj
.
weight
)
w2
.
append
(
expert
.
down_proj
.
weight
)
self
.
w1
=
torch
.
_utils
.
_flatten_dense_tensors
(
w1
)
w1s
=
torch
.
_utils
.
_unflatten_dense_tensors
(
self
.
w1
,
w1
)
for
data
,
param
in
zip
(
w1s
,
w1
):
param
.
data
=
data
self
.
w1
=
self
.
w1
.
view
(
len
(
w1
),
*
w1s
[
0
].
shape
)
self
.
w2
=
torch
.
_utils
.
_flatten_dense_tensors
(
w2
)
w2s
=
torch
.
_utils
.
_unflatten_dense_tensors
(
self
.
w2
,
w2
)
for
data
,
param
in
zip
(
w2s
,
w2
):
param
.
data
=
data
self
.
w2
=
self
.
w2
.
view
(
len
(
w2
),
*
w2s
[
0
].
shape
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
config
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
,
router_logits
,
self
.
top_k
,
renormalize
=
self
.
config
.
norm_topk_prob
,
num_expert_group
=
self
.
config
.
n_group
,
topk_group
=
self
.
config
.
topk_group
)
final_hidden_states
=
fused_experts
(
hidden_states
,
self
.
w1
,
self
.
w2
,
topk_weights
,
topk_ids
,
inplace
=
True
)
*
self
.
routed_scaling_factor
if
self
.
config
.
n_shared_experts
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
def
yarn_get_mscale
(
scale
:
float
=
1
,
mscale
:
float
=
1
)
->
float
:
import
math
if
scale
<=
1
:
return
1.0
return
0.1
*
mscale
*
math
.
log
(
scale
)
+
1.0
class
DeepseekV2Attention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
hidden_size
:
int
,
num_heads
:
int
,
qk_nope_head_dim
:
int
,
qk_rope_head_dim
:
int
,
v_head_dim
:
int
,
q_lora_rank
:
int
,
kv_lora_rank
:
int
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
layer_idx
=
None
,
)
->
None
:
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
hidden_size
=
hidden_size
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
qk_head_dim
=
qk_nope_head_dim
+
qk_rope_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
q_lora_rank
=
q_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
num_heads
=
num_heads
tp_size
=
get_tensor_model_parallel_world_size
()
assert
num_heads
%
tp_size
==
0
self
.
num_local_heads
=
num_heads
//
tp_size
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
if
self
.
q_lora_rank
is
not
None
:
self
.
q_a_proj
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
q_lora_rank
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
q_a_layernorm
=
RMSNorm
(
self
.
q_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
q_b_proj
=
ColumnParallelLinear
(
q_lora_rank
,
self
.
num_heads
*
self
.
qk_head_dim
,
bias
=
False
,
quant_config
=
quant_config
)
else
:
self
.
q_proj
=
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
num_heads
*
self
.
qk_head_dim
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
kv_a_proj_with_mqa
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
kv_b_proj
=
ColumnParallelLinear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
(
self
.
qk_nope_head_dim
+
self
.
v_head_dim
),
bias
=
False
,
quant_config
=
quant_config
)
# O projection.
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
v_head_dim
,
self
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
rope_scaling
[
'type'
]
=
'deepseek_yarn'
self
.
rotary_emb
=
get_rope
(
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
is_neox_style
=
False
)
if
rope_scaling
:
mscale_all_dim
=
rope_scaling
.
get
(
"mscale_all_dim"
,
False
)
scaling_factor
=
rope_scaling
[
"factor"
]
mscale
=
yarn_get_mscale
(
scaling_factor
,
float
(
mscale_all_dim
))
self
.
scaling
=
self
.
scaling
*
mscale
*
mscale
# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)
# TODO, support head_size 192
self
.
attn
=
Attention
(
self
.
num_local_heads
,
256
,
self
.
scaling
,
num_kv_heads
=
self
.
num_local_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
if
self
.
q_lora_rank
is
not
None
:
q
=
self
.
q_a_proj
(
hidden_states
)[
0
]
q
=
self
.
q_a_layernorm
(
q
)
q
=
self
.
q_b_proj
(
q
)[
0
].
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
].
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
q_nope
,
q_pe
=
q
.
split
([
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
latent_cache
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
]
kv_a
,
_
=
latent_cache
.
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
latent_cache
=
latent_cache
.
unsqueeze
(
1
)
kv_a
=
self
.
kv_a_layernorm
(
kv_a
.
contiguous
())
kv
=
self
.
kv_b_proj
(
kv_a
)[
0
]
kv
=
kv
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k_pe
=
latent_cache
[:,
:,
self
.
kv_lora_rank
:]
q_pe
,
k_pe
=
self
.
rotary_emb
(
positions
,
q_pe
,
k_pe
)
q
[...,
self
.
qk_nope_head_dim
:]
=
q_pe
k
=
torch
.
empty_like
(
q
)
k
[...,
:
self
.
qk_nope_head_dim
]
=
k_nope
k
[...,
self
.
qk_nope_head_dim
:]
=
k_pe
q
=
torch
.
nn
.
functional
.
pad
(
q
,
[
0
,
256
-
self
.
qk_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
256
)
k
=
torch
.
nn
.
functional
.
pad
(
k
,
[
0
,
256
-
self
.
qk_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
256
)
v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
256
-
self
.
v_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
256
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
attn_output
.
view
(
-
1
,
self
.
num_local_heads
,
256
)[...,
:
self
.
v_head_dim
].
reshape
(
-
1
,
self
.
num_local_heads
*
self
.
v_head_dim
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
self
.
self_attn
=
DeepseekV2Attention
(
config
=
config
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
qk_nope_head_dim
=
config
.
qk_nope_head_dim
,
qk_rope_head_dim
=
config
.
qk_rope_head_dim
,
v_head_dim
=
config
.
v_head_dim
,
q_lora_rank
=
config
.
q_lora_rank
if
hasattr
(
config
,
"q_lora_rank"
)
else
None
,
kv_lora_rank
=
config
.
kv_lora_rank
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
layer_idx
=
layer_idx
,
)
if
(
config
.
n_routed_experts
is
not
None
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
self
.
mlp
=
DeepseekV2MoE
(
config
=
config
,
quant_config
=
quant_config
)
else
:
self
.
mlp
=
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
class
DeepseekV2Model
(
nn
.
Module
):
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
DeepseekV2DecoderLayer
(
config
,
layer_idx
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
attn_metadata
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
DeepseekV2ForCausalLM
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
DeepseekV2Model
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/llama.py
View file @
7462218e
...
@@ -26,6 +26,8 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
...
@@ -26,6 +26,8 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
@@ -49,6 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -49,6 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
class
LlamaMLP
(
nn
.
Module
):
class
LlamaMLP
(
nn
.
Module
):
...
@@ -147,6 +152,11 @@ class LlamaAttention(nn.Module):
...
@@ -147,6 +152,11 @@ class LlamaAttention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
...
@@ -156,6 +166,8 @@ class LlamaAttention(nn.Module):
...
@@ -156,6 +166,8 @@ class LlamaAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
...
@@ -360,6 +372,15 @@ class LlamaForCausalLM(nn.Module):
...
@@ -360,6 +372,15 @@ class LlamaForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
def
forward
(
self
,
self
,
...
@@ -435,8 +456,79 @@ class LlamaForCausalLM(nn.Module):
...
@@ -435,8 +456,79 @@ class LlamaForCausalLM(nn.Module):
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attn.qkv_proj.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
if
self
.
quant_method
==
"awq"
:
lay_key_words
=
[
"self_attn.qkv_proj.qweight"
,
"self_attn.o_proj.qweight"
,
"mlp.gate_up_proj.qweight"
,
"mlp.down_proj.qweight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
qweight
=
params_dict
[
layername
]
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
zeros_and_scalse
=
params_dict
[
layername
.
replace
(
"qweight"
,
"zeros_and_scales"
)]
group_size
=
self
.
quant_config
.
group_size
dim_n
=
scales
.
data
.
shape
[
1
]
dim_k
=
qweight
.
data
.
shape
[
0
]
pad_group
=
2
_qw
,
_sz
=
ops
.
convert_s4
(
qweight
,
qzeros
,
scales
,
int
(
group_size
))
sz
=
ops
.
sz_permute
(
_sz
).
reshape
(
-
1
,
dim_n
)
zeros_and_scalse
.
data
.
copy_
(
sz
)
qweight
.
data
.
copy_
(
_qw
)
#reshape
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
if
dim_k
%
4096
==
0
:
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
# If this function is called, it should always initialize KV cache scale
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
# make sure to leave KV cache scale factors in a known good (dummy) state
...
...
vllm/model_executor/models/qwen.py
View file @
7462218e
...
@@ -10,6 +10,9 @@ import torch
...
@@ -10,6 +10,9 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
...
@@ -29,6 +32,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -29,6 +32,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
class
QWenMLP
(
nn
.
Module
):
class
QWenMLP
(
nn
.
Module
):
...
@@ -108,6 +114,11 @@ class QWenAttention(nn.Module):
...
@@ -108,6 +114,11 @@ class QWenAttention(nn.Module):
self
.
scaling
,
self
.
scaling
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
...
@@ -117,6 +128,8 @@ class QWenAttention(nn.Module):
...
@@ -117,6 +128,8 @@ class QWenAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
...
@@ -237,6 +250,15 @@ class QWenLMHeadModel(nn.Module):
...
@@ -237,6 +250,15 @@ class QWenLMHeadModel(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
def
forward
(
self
,
self
,
...
@@ -292,3 +314,80 @@ class QWenLMHeadModel(nn.Module):
...
@@ -292,3 +314,80 @@ class QWenLMHeadModel(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"attn.c_attn.weight"
,
"attn.c_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.c_proj.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"attn.c_attn.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
lay_qkv_bias_words
=
[
"attn.c_attn.bias"
]
qkv_bias_words
=
"|"
.
join
(
lay_qkv_bias_words
)
for
layername
,
weight
in
params_dict
.
items
():
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_bias_words
,
layername
)):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
if
self
.
quant_method
==
"awq"
:
lay_key_words
=
[
"attn.c_attn.qweight"
,
"attn.c_proj.qweight"
,
"mlp.gate_up_proj.qweight"
,
"mlp.c_proj.qweight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
qweight
=
params_dict
[
layername
]
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
zeros_and_scalse
=
params_dict
[
layername
.
replace
(
"qweight"
,
"zeros_and_scales"
)]
group_size
=
self
.
quant_config
.
group_size
dim_n
=
scales
.
data
.
shape
[
1
]
dim_k
=
qweight
.
data
.
shape
[
0
]
pad_group
=
2
_qw
,
_sz
=
ops
.
convert_s4
(
qweight
,
qzeros
,
scales
,
int
(
group_size
))
sz
=
ops
.
sz_permute
(
_sz
).
reshape
(
-
1
,
dim_n
)
zeros_and_scalse
.
data
.
copy_
(
sz
)
qweight
.
data
.
copy_
(
_qw
)
#reshape
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
if
dim_k
%
4096
==
0
:
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
vllm/model_executor/models/qwen2.py
View file @
7462218e
...
@@ -27,6 +27,8 @@ from typing import Iterable, List, Optional, Tuple
...
@@ -27,6 +27,8 @@ from typing import Iterable, List, Optional, Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Qwen2Config
from
transformers
import
Qwen2Config
import
os
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
@@ -47,6 +49,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -47,6 +49,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
class
Qwen2MLP
(
nn
.
Module
):
class
Qwen2MLP
(
nn
.
Module
):
...
@@ -139,6 +144,11 @@ class Qwen2Attention(nn.Module):
...
@@ -139,6 +144,11 @@ class Qwen2Attention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
...
@@ -148,6 +158,8 @@ class Qwen2Attention(nn.Module):
...
@@ -148,6 +158,8 @@ class Qwen2Attention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
...
@@ -319,6 +331,15 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -319,6 +331,15 @@ class Qwen2ForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
def
forward
(
self
,
self
,
...
@@ -379,3 +400,81 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -379,3 +400,81 @@ class Qwen2ForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attn.qkv_proj.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
lay_qkv_bias_words
=
[
"self_attn.qkv_proj.bias"
]
qkv_bias_words
=
"|"
.
join
(
lay_qkv_bias_words
)
for
layername
,
weight
in
params_dict
.
items
():
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_bias_words
,
layername
)):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
if
self
.
quant_method
==
"awq"
:
lay_key_words
=
[
"self_attn.qkv_proj.qweight"
,
"self_attn.o_proj.qweight"
,
"mlp.gate_up_proj.qweight"
,
"mlp.down_proj.qweight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
qweight
=
params_dict
[
layername
]
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
zeros_and_scalse
=
params_dict
[
layername
.
replace
(
"qweight"
,
"zeros_and_scales"
)]
group_size
=
self
.
quant_config
.
group_size
dim_n
=
scales
.
data
.
shape
[
1
]
dim_k
=
qweight
.
data
.
shape
[
0
]
pad_group
=
2
_qw
,
_sz
=
ops
.
convert_s4
(
qweight
,
qzeros
,
scales
,
int
(
group_size
))
sz
=
ops
.
sz_permute
(
_sz
).
reshape
(
-
1
,
dim_n
)
zeros_and_scalse
.
data
.
copy_
(
sz
)
qweight
.
data
.
copy_
(
_qw
)
#reshape
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
if
dim_k
%
4096
==
0
:
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
\ No newline at end of file
vllm/model_executor/utils.py
View file @
7462218e
...
@@ -33,3 +33,32 @@ def set_weight_attrs(
...
@@ -33,3 +33,32 @@ def set_weight_attrs(
assert
not
hasattr
(
assert
not
hasattr
(
weight
,
key
),
(
f
"Overwriting existing tensor attribute:
{
key
}
"
)
weight
,
key
),
(
f
"Overwriting existing tensor attribute:
{
key
}
"
)
setattr
(
weight
,
key
,
value
)
setattr
(
weight
,
key
,
value
)
def
pad_weight
(
weight
:
torch
.
Tensor
,
num_pad
:
int
,
pad_dim
:
int
=
0
):
if
weight
.
dim
()
==
1
:
padding
=
torch
.
zeros
(
num_pad
,
dtype
=
weight
.
dtype
,
device
=
weight
.
device
)
padded_weight
=
torch
.
cat
([
weight
,
padding
],
dim
=
0
)
elif
weight
.
dim
()
==
2
:
if
pad_dim
==
0
:
padding
=
torch
.
zeros
(
num_pad
,
weight
.
shape
[
1
],
dtype
=
weight
.
dtype
,
device
=
weight
.
device
)
padded_weight
=
torch
.
cat
([
weight
,
padding
],
dim
=
0
)
elif
pad_dim
==
1
:
padding
=
torch
.
zeros
(
weight
.
shape
[
0
],
num_pad
,
dtype
=
weight
.
dtype
,
device
=
weight
.
device
)
padded_weight
=
torch
.
cat
([
weight
,
padding
],
dim
=
1
)
else
:
raise
ValueError
(
"pad_dim must be 0 or 1"
)
else
:
raise
ValueError
(
"Weight tensor must be 1D or 2D"
)
padded_weight
=
padded_weight
.
contiguous
()
return
padded_weight
def
gemm_bank_conf
(
weight
):
is_mul_of_2048
=
weight
%
2048
==
0
is_power_of_two
=
(
weight
&
(
weight
-
1
))
==
0
and
weight
!=
0
if
is_mul_of_2048
and
is_power_of_two
:
return
True
else
:
return
False
\ No newline at end of file
vllm/worker/model_runner.py
View file @
7462218e
...
@@ -815,6 +815,34 @@ class ModelRunner:
...
@@ -815,6 +815,34 @@ class ModelRunner:
max_num_seqs
=
min
(
max_num_seqs
=
min
(
max_num_seqs
,
max_num_seqs
,
int
(
max_num_batched_tokens
/
vlm_config
.
image_feature_size
))
int
(
max_num_batched_tokens
/
vlm_config
.
image_feature_size
))
import
vllm.envs
as
envs
if
envs
.
VLLM_USE_FLASH_ATTN_AUTO
:
for
group_id
in
range
(
1
):
if
max_num_batched_tokens
>=
8000
:
seq_len
=
8000
else
:
seq_len
=
max_num_batched_tokens
if
vlm_config
is
None
:
seq_data
=
SequenceData
([
0
]
*
seq_len
)
dummy_multi_modal_data
=
None
else
:
seq_data
,
dummy_multi_modal_data
=
MULTIMODAL_REGISTRY
\
.
dummy_data_for_profiling
(
seq_len
,
model_config
,
vlm_config
)
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
is_prompt
=
True
,
seq_data
=
{
group_id
:
seq_data
},
sampling_params
=
sampling_params
,
block_tables
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
dummy_multi_modal_data
,
)
seqs
.
append
(
seq
)
max_num_batched_tokens
-=
seq_len
for
group_id
in
range
(
max_num_seqs
):
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment