Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8824ae6a
Commit
8824ae6a
authored
Sep 18, 2025
by
王敏
Browse files
merge 092-dev分支近期修改
parents
f9f1887d
c0707728
Changes
36
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
305 additions
and
275 deletions
+305
-275
csrc/ops.h
csrc/ops.h
+2
-2
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-1
setup.py
setup.py
+3
-3
tests/engine/test_computed_prefix_blocks.py
tests/engine/test_computed_prefix_blocks.py
+1
-2
tests/engine/test_executor.py
tests/engine/test_executor.py
+3
-4
tests/engine/test_multiproc_workers.py
tests/engine/test_multiproc_workers.py
+18
-17
tests/engine/test_stop_strings.py
tests/engine/test_stop_strings.py
+0
-167
vllm/attention/backends/dual_chunk_flash_attn.py
vllm/attention/backends/dual_chunk_flash_attn.py
+36
-17
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+7
-4
vllm/attention/layer.py
vllm/attention/layer.py
+9
-3
vllm/attention/ops/merge_attn_states.py
vllm/attention/ops/merge_attn_states.py
+2
-1
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+15
-3
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+0
-4
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
+33
-12
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+1
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+146
-23
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+2
-2
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+7
-3
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+10
-5
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+8
-1
No files found.
csrc/ops.h
View file @
8824ae6a
...
@@ -171,8 +171,6 @@ void paged_attention_v2_opt_tc_with_mask(
...
@@ -171,8 +171,6 @@ void paged_attention_v2_opt_tc_with_mask(
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
c10
::
optional
<
torch
::
Tensor
>&
attn_masks
,
const
int64_t
attn_masks_stride
=
0
);
const
int64_t
attn_masks_stride
=
0
);
#ifndef USE_ROCM
void
merge_attn_states
(
torch
::
Tensor
&
output
,
void
merge_attn_states
(
torch
::
Tensor
&
output
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_output
,
...
@@ -180,6 +178,8 @@ void merge_attn_states(torch::Tensor& output,
...
@@ -180,6 +178,8 @@ void merge_attn_states(torch::Tensor& output,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
);
const
torch
::
Tensor
&
suffix_lse
);
#ifndef USE_ROCM
void
convert_vertical_slash_indexes
(
void
convert_vertical_slash_indexes
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
...
...
csrc/torch_bindings.cpp
View file @
8824ae6a
...
@@ -216,7 +216,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -216,7 +216,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int attn_masks_stride) -> ()"
);
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt_tc_with_mask"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt_tc_with_mask
);
ops
.
impl
(
"paged_attention_v2_opt_tc_with_mask"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt_tc_with_mask
);
#ifndef USE_ROCM
// Merge attn states
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
// can be used to combine partial attention results (in the split-KV case)
...
@@ -230,6 +229,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -230,6 +229,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()"
);
" Tensor suffix_lse) -> ()"
);
ops
.
impl
(
"merge_attn_states"
,
torch
::
kCUDA
,
&
merge_attn_states
);
ops
.
impl
(
"merge_attn_states"
,
torch
::
kCUDA
,
&
merge_attn_states
);
#ifndef USE_ROCM
ops
.
def
(
ops
.
def
(
"convert_vertical_slash_indexes("
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! block_count, Tensor! block_offset, "
...
...
setup.py
View file @
8824ae6a
...
@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
...
@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if
sha
is
None
:
if
sha
is
None
:
sha
=
get_sha
(
vllm_root
)
sha
=
get_sha
(
vllm_root
)
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
version
=
'das.opt1.rc
1
.'
+
sha
[:
7
]
version
=
'das.opt1.rc
2
.'
+
sha
[:
7
]
else
:
else
:
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
version
=
'das.opt1.rc
1
'
version
=
'das.opt1.rc
2
'
# dtk version
# dtk version
...
...
tests/engine/test_computed_prefix_blocks.py
View file @
8824ae6a
...
@@ -8,12 +8,11 @@ from vllm.engine.arg_utils import EngineArgs
...
@@ -8,12 +8,11 @@ from vllm.engine.arg_utils import EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
..utils
import
models_path_prefix
from
..utils
import
models_path_prefix
from
vllm.utils
import
SUPPORT_TC
,
gpuname
import
vllm.envs
as
envs
import
vllm.envs
as
envs
@
pytest
.
mark
.
parametrize
(
"model"
,
[
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
)])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
os
.
path
.
join
(
models_path_prefix
,
"distilbert/distilgpt2"
)])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
64
]
if
gpuname
.
startswith
(
'BW'
)
and
envs
.
VLLM_FLASH_ATTN_
BACKEND
else
[
16
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
64
]
if
envs
.
VLLM_
USE_
FLASH_ATTN_
PA
else
[
16
])
def
test_computed_prefix_blocks
(
model
:
str
,
block_size
:
int
):
def
test_computed_prefix_blocks
(
model
:
str
,
block_size
:
int
):
# This test checks if we are able to run the engine to completion
# This test checks if we are able to run the engine to completion
# without triggering asserts.
# without triggering asserts.
...
...
tests/engine/test_executor.py
View file @
8824ae6a
...
@@ -14,7 +14,6 @@ from vllm.executor.uniproc_executor import UniProcExecutor
...
@@ -14,7 +14,6 @@ from vllm.executor.uniproc_executor import UniProcExecutor
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
import
os
import
os
from
..utils
import
models_path_prefix
from
..utils
import
models_path_prefix
from
vllm.utils
import
SUPPORT_TC
,
gpuname
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -60,7 +59,7 @@ def test_custom_executor(model, tmp_path):
...
@@ -60,7 +59,7 @@ def test_custom_executor(model, tmp_path):
model
=
model
,
model
=
model
,
distributed_executor_backend
=
CustomUniExecutor
,
distributed_executor_backend
=
CustomUniExecutor
,
enforce_eager
=
True
,
# reduce test time
enforce_eager
=
True
,
# reduce test time
block_size
=
64
if
gpuname
.
startswith
(
'BW'
)
and
envs
.
VLLM_FLASH_ATTN_
BACKEND
else
16
,
block_size
=
64
if
envs
.
VLLM_
USE_
FLASH_ATTN_
PA
else
16
,
)
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
...
@@ -84,7 +83,7 @@ def test_custom_executor_async(model, tmp_path):
...
@@ -84,7 +83,7 @@ def test_custom_executor_async(model, tmp_path):
model
=
model
,
model
=
model
,
distributed_executor_backend
=
CustomUniExecutorAsync
,
distributed_executor_backend
=
CustomUniExecutorAsync
,
enforce_eager
=
True
,
# reduce test time
enforce_eager
=
True
,
# reduce test time
block_size
=
64
if
gpuname
.
startswith
(
'BW'
)
and
envs
.
VLLM_FLASH_ATTN_
BACKEND
else
16
,
block_size
=
64
if
envs
.
VLLM_
USE_
FLASH_ATTN_
PA
else
16
,
)
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
...
@@ -111,7 +110,7 @@ def test_respect_ray(model):
...
@@ -111,7 +110,7 @@ def test_respect_ray(model):
model
=
model
,
model
=
model
,
distributed_executor_backend
=
"ray"
,
distributed_executor_backend
=
"ray"
,
enforce_eager
=
True
,
# reduce test time
enforce_eager
=
True
,
# reduce test time
block_size
=
64
if
gpuname
.
startswith
(
'BW'
)
and
envs
.
VLLM_FLASH_ATTN_
BACKEND
else
16
,
block_size
=
64
if
envs
.
VLLM_
USE_
FLASH_ATTN_
PA
else
16
,
)
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
assert
engine
.
model_executor
.
uses_ray
assert
engine
.
model_executor
.
uses_ray
\ No newline at end of file
tests/engine/test_multiproc_workers.py
View file @
8824ae6a
...
@@ -100,29 +100,30 @@ def test_local_workers() -> None:
...
@@ -100,29 +100,30 @@ def test_local_workers() -> None:
assert
isinstance
(
e
,
ChildProcessError
)
assert
isinstance
(
e
,
ChildProcessError
)
def
test_local_workers_clean_shutdown
()
->
None
:
# @TODO
"""Test clean shutdown"""
# def test_local_workers_clean_shutdown() -> None:
# """Test clean shutdown"""
workers
,
worker_monitor
=
_start_workers
()
#
workers, worker_monitor = _start_workers()
assert
worker_monitor
.
is_alive
()
#
assert worker_monitor.is_alive()
assert
all
(
worker
.
process
.
is_alive
()
for
worker
in
workers
)
#
assert all(worker.process.is_alive() for worker in workers)
# Clean shutdown
#
# Clean shutdown
worker_monitor
.
close
()
#
worker_monitor.close()
worker_monitor
.
join
(
20
)
#
worker_monitor.join(20)
# Ensure everything is stopped
#
# Ensure everything is stopped
assert
not
worker_monitor
.
is_alive
()
#
assert not worker_monitor.is_alive()
assert
all
(
not
worker
.
process
.
is_alive
()
for
worker
in
workers
)
#
assert all(not worker.process.is_alive() for worker in workers)
# Further attempts to submit tasks should fail
#
# Further attempts to submit tasks should fail
try
:
#
try:
_result
=
workers
[
0
].
execute_method
(
"worker_method"
,
"test"
)
#
_result = workers[0].execute_method("worker_method", "test")
pytest
.
fail
(
"task should fail once workers have been shut down"
)
#
pytest.fail("task should fail once workers have been shut down")
except
Exception
as
e
:
#
except Exception as e:
assert
isinstance
(
e
,
ChildProcessError
)
#
assert isinstance(e, ChildProcessError)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
...
tests/engine/test_stop_strings.py
deleted
100644 → 0
View file @
f9f1887d
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
List
,
Optional
import
pytest
import
os
from
vllm
import
CompletionOutput
,
LLMEngine
,
SamplingParams
from
..utils
import
models_path_prefix
MODEL
=
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/llama-2-7b-hf"
)
MAX_TOKENS
=
200
IS_ASYNC
=
False
@
pytest
.
fixture
(
scope
=
"session"
)
def
vllm_model
(
vllm_runner
):
with
vllm_runner
(
MODEL
)
as
vllm_model
:
yield
vllm_model
def
_test_stopping
(
llm_engine
:
LLMEngine
,
expected_output
:
str
,
expected_reason
:
Any
,
stop
:
Optional
[
List
[
str
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
include_in_output
:
bool
=
False
,
use_async_output_proc
:
bool
=
False
)
->
None
:
llm_engine
.
add_request
(
"id"
,
"A story about vLLM:
\n
"
,
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
MAX_TOKENS
,
stop
=
stop
,
stop_token_ids
=
stop_token_ids
,
include_stop_str_in_output
=
include_in_output
,
),
None
)
output
:
Optional
[
CompletionOutput
]
=
None
output_text
=
""
stop_reason
=
None
if
use_async_output_proc
:
llm_engine
.
step
()
while
llm_engine
.
has_unfinished_requests
():
(
request_output
,
)
=
llm_engine
.
step
()
(
output
,
)
=
request_output
.
outputs
# Ensure we don't backtrack
assert
output
.
text
.
startswith
(
output_text
)
output_text
=
output
.
text
stop_reason
=
output
.
stop_reason
assert
output
is
not
None
assert
output_text
==
expected_output
assert
stop_reason
==
expected_reason
def
_set_async_mode
(
llm_engine
,
is_async
):
llm_engine
.
scheduler
[
0
].
use_async_output_proc
=
is_async
def
_stop_basic
(
llm_engine
,
is_async
):
_test_stopping
(
llm_engine
,
stop
=
[
"."
],
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer organization"
,
expected_reason
=
"."
,
use_async_output_proc
=
is_async
)
_test_stopping
(
llm_engine
,
stop
=
[
"."
],
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organization."
,
expected_reason
=
"."
,
use_async_output_proc
=
is_async
)
def
_stop_multi_tokens
(
llm_engine
,
is_async
):
_test_stopping
(
llm_engine
,
stop
=
[
"group of peo"
,
"short"
],
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer organization. We are a "
,
expected_reason
=
"group of peo"
,
use_async_output_proc
=
is_async
)
_test_stopping
(
llm_engine
,
stop
=
[
"group of peo"
,
"short"
],
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organization. We are a group of peo"
,
expected_reason
=
"group of peo"
,
use_async_output_proc
=
is_async
)
def
_stop_partial_token
(
llm_engine
,
is_async
):
_test_stopping
(
llm_engine
,
stop
=
[
"gani"
],
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer or"
,
expected_reason
=
"gani"
,
use_async_output_proc
=
is_async
)
_test_stopping
(
llm_engine
,
stop
=
[
"gani"
],
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organi"
,
expected_reason
=
"gani"
,
use_async_output_proc
=
is_async
)
def
_stop_token_id
(
llm_engine
,
is_async
):
# token id 13013 => " organization"
_test_stopping
(
llm_engine
,
stop_token_ids
=
[
13013
],
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer"
,
expected_reason
=
13013
,
use_async_output_proc
=
is_async
)
_test_stopping
(
llm_engine
,
stop_token_ids
=
[
13013
],
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organization"
,
expected_reason
=
13013
,
use_async_output_proc
=
is_async
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_stop_basic
(
vllm_model
):
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
True
)
_stop_basic
(
vllm_model
.
model
.
llm_engine
,
is_async
=
True
)
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
False
)
_stop_basic
(
vllm_model
.
model
.
llm_engine
,
is_async
=
False
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_stop_multi_tokens
(
vllm_model
):
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
True
)
_stop_multi_tokens
(
vllm_model
.
model
.
llm_engine
,
is_async
=
True
)
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
False
)
_stop_multi_tokens
(
vllm_model
.
model
.
llm_engine
,
is_async
=
False
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_stop_partial_token
(
vllm_model
):
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
True
)
_stop_partial_token
(
vllm_model
.
model
.
llm_engine
,
is_async
=
True
)
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
False
)
_stop_partial_token
(
vllm_model
.
model
.
llm_engine
,
is_async
=
False
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_stop_token_id
(
vllm_model
):
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
True
)
_stop_token_id
(
vllm_model
.
model
.
llm_engine
,
is_async
=
True
)
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
False
)
_stop_token_id
(
vllm_model
.
model
.
llm_engine
,
is_async
=
False
)
vllm/attention/backends/dual_chunk_flash_attn.py
View file @
8824ae6a
...
@@ -1221,6 +1221,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -1221,6 +1221,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
s_lse
=
s_lse
.
view
(
q_len
,
q_heads
,
1
).
transpose
(
0
,
2
).
float
()
s_lse
=
s_lse
.
view
(
q_len
,
q_heads
,
1
).
transpose
(
0
,
2
).
float
()
return
res
,
s_lse
return
res
,
s_lse
if
not
current_platform
.
is_rocm
():
output
,
softmax_lse
=
flash_attn_varlen_func
(
output
,
softmax_lse
=
flash_attn_varlen_func
(
q
=
query_states
,
q
=
query_states
,
k
=
key_states
,
k
=
key_states
,
...
@@ -1238,6 +1239,24 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -1238,6 +1239,24 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
block_table
=
block_table
.
unsqueeze
(
0
),
block_table
=
block_table
.
unsqueeze
(
0
),
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
)
)
else
:
output
,
softmax_lse
=
flash_attn_varlen_func
(
q
=
query_states
,
k
=
key_states
,
v
=
value_states
,
softmax_scale
=
softmax_scale
,
cu_seqlens_q
=
torch
.
tensor
([
0
,
query_states
.
shape
[
0
]],
dtype
=
torch
.
int32
,
device
=
query_states
.
device
),
max_seqlen_q
=
query_states
.
shape
[
0
],
cu_seqlens_k
=
torch
.
tensor
([
0
,
max_seqlen_k
],
dtype
=
torch
.
int32
,
device
=
query_states
.
device
),
max_seqlen_k
=
max_seqlen_k
,
causal
=
causal
,
block_table
=
block_table
.
unsqueeze
(
0
),
return_attn_probs
=
True
,
)
softmax_lse
=
softmax_lse
.
view
(
q_len
,
q_heads
,
1
).
transpose
(
0
,
softmax_lse
=
softmax_lse
.
view
(
q_len
,
q_heads
,
1
).
transpose
(
0
,
2
).
float
()
2
).
float
()
return
output
,
softmax_lse
return
output
,
softmax_lse
...
...
vllm/attention/backends/mla/common.py
View file @
8824ae6a
...
@@ -1043,9 +1043,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1043,9 +1043,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# v with 0s to match the qk head dim for attention backends that do
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
or
not
(
if
not
current_platform
.
is_rocm
():
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
else
:
self
.
_pad_v
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
softmax_scale
,
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
softmax_scale
,
return_softmax_lse
,
**
kwargs
):
return_softmax_lse
,
**
kwargs
):
...
...
vllm/attention/layer.py
View file @
8824ae6a
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_maybe_save_kv_layer_to_connector
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionType
from
vllm.attention
import
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
...
@@ -412,6 +413,9 @@ def unified_attention(
...
@@ -412,6 +413,9 @@ def unified_attention(
output
=
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
output
=
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
attn_metadata
)
attn_metadata
)
if
envs
.
VLLM_ENABLE_TBO
:
tbo_maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
else
:
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
return
output
return
output
...
@@ -457,7 +461,9 @@ def unified_attention_with_output(
...
@@ -457,7 +461,9 @@ def unified_attention_with_output(
attn_metadata
,
attn_metadata
,
output
=
output
,
output
=
output
,
output_scale
=
output_scale
)
output_scale
=
output_scale
)
if
envs
.
VLLM_ENABLE_TBO
:
tbo_maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
else
:
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
...
...
vllm/attention/ops/merge_attn_states.py
View file @
8824ae6a
...
@@ -5,6 +5,7 @@ from typing import Optional
...
@@ -5,6 +5,7 @@ from typing import Optional
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm
import
envs
def
merge_attn_states
(
def
merge_attn_states
(
...
@@ -31,7 +32,7 @@ def merge_attn_states(
...
@@ -31,7 +32,7 @@ def merge_attn_states(
return
headdim
%
4
==
0
return
headdim
%
4
==
0
return
headdim
%
8
==
0
return
headdim
%
8
==
0
if
(
current_platform
.
is_cuda
()
and
supported_dtypes
(
output
)
if
(
current_platform
.
is_cuda
()
or
envs
.
VLLM_USE_MERGE_ATTN_STATES_OPT
and
supported_dtypes
(
output
)
and
supported_headdim
(
output
)):
and
supported_headdim
(
output
)):
from
vllm._custom_ops
import
merge_attn_states
from
vllm._custom_ops
import
merge_attn_states
return
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
return
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
8824ae6a
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import
regex
as
re
import
regex
as
re
import
torch
import
torch
from
vllm
import
envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
...
@@ -215,9 +216,18 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -215,9 +216,18 @@ class P2pNcclConnector(KVConnectorBase_V1):
inject_kv_into_layer
(
kv_cache_layer
,
kv_cache
,
inject_kv_into_layer
(
kv_cache_layer
,
kv_cache
,
request
.
slot_mapping
,
request
.
request_id
)
request
.
slot_mapping
,
request
.
request_id
)
tensor
=
self
.
p2p_nccl_engine
.
recv_store
.
pop
(
request
.
request_id
+
"#"
+
layer_name
,
None
)
tensor_id
=
request
.
request_id
+
"#"
+
layer_name
if
tensor
is
not
None
:
if
tensor_id
in
self
.
p2p_nccl_engine
.
recv_store
:
del
tensor
tensor
=
self
.
p2p_nccl_engine
.
recv_store
.
pop
(
tensor_id
,
None
)
self
.
p2p_nccl_engine
.
send_request_id_to_tensor_ids
.
pop
(
request
.
request_id
,
None
)
self
.
p2p_nccl_engine
.
recv_request_id_to_tensor_ids
.
pop
(
request
.
request_id
,
None
)
addr
=
0
if
isinstance
(
tensor
,
tuple
):
addr
,
_
,
_
=
tensor
self
.
p2p_nccl_engine
.
pool
.
free
(
addr
)
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
"""Blocking until the KV for a specific layer is loaded into vLLM's
"""Blocking until the KV for a specific layer is loaded into vLLM's
...
@@ -258,6 +268,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -258,6 +268,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx)
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
"""
if
envs
.
VLLM_ENABLE_TBO
:
slot_mapping
=
slot_mapping
.
pin_memory
().
to
(
device
=
layer
.
device
,
non_blocking
=
True
)
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
8824ae6a
...
@@ -326,10 +326,6 @@ class P2pNcclEngine:
...
@@ -326,10 +326,6 @@ class P2pNcclEngine:
# Store Tensor in memory pool
# Store Tensor in memory pool
addr
=
self
.
pool
.
store_tensor
(
tensor
)
addr
=
self
.
pool
.
store_tensor
(
tensor
)
tensor
=
(
addr
,
tensor
.
dtype
,
tensor
.
shape
)
tensor
=
(
addr
,
tensor
.
dtype
,
tensor
.
shape
)
logger
.
warning
(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d"
,
self
.
zmq_address
,
remote_address
.
decode
(),
data
,
addr
)
else
:
else
:
self
.
buffer_size
+=
tensor_size
self
.
buffer_size
+=
tensor_size
...
...
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
View file @
8824ae6a
...
@@ -7,6 +7,7 @@ from collections.abc import Iterable
...
@@ -7,6 +7,7 @@ from collections.abc import Iterable
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
...
@@ -259,7 +260,7 @@ class EPMoE(FusedMoE):
...
@@ -259,7 +260,7 @@ class EPMoE(FusedMoE):
hidden_dim
=
self
.
hidden_size
,
hidden_dim
=
self
.
hidden_size
,
scale_dim
=
0
,
scale_dim
=
0
,
scale_type_size
=
vllm_config
.
model_config
.
dtype
.
itemsize
,
scale_type_size
=
vllm_config
.
model_config
.
dtype
.
itemsize
,
max_num_inp_token_per_rank
=
5120
,
max_num_inp_token_per_rank
=
4096
,
num_experts_per_rank
=
self
.
local_num_experts
,
num_experts_per_rank
=
self
.
local_num_experts
,
num_experts_per_token
=
self
.
top_k
,
num_experts_per_token
=
self
.
top_k
,
max_token_type_size
=
2
,
max_token_type_size
=
2
,
...
@@ -294,7 +295,9 @@ class EPMoE(FusedMoE):
...
@@ -294,7 +295,9 @@ class EPMoE(FusedMoE):
dist
.
barrier
()
dist
.
barrier
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
router_logits
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
):
return
torch
.
ops
.
vllm
.
ep_moe_forward
(
hidden_states
,
router_logits
,
return
torch
.
ops
.
vllm
.
ep_moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
)
self
.
layer_name
)
...
@@ -318,7 +321,10 @@ class EPMoE(FusedMoE):
...
@@ -318,7 +321,10 @@ class EPMoE(FusedMoE):
]
]
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
):
topk_weights
,
topk_ids
=
self
.
select_experts
(
topk_weights
,
topk_ids
=
self
.
select_experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -334,7 +340,11 @@ class EPMoE(FusedMoE):
...
@@ -334,7 +340,11 @@ class EPMoE(FusedMoE):
indices_type
=
torch
.
int64
,
indices_type
=
torch
.
int64
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
use_fused_gate
=
self
.
use_fused_gate
)
use_fused_gate
=
self
.
use_fused_gate
)
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
...
@@ -345,6 +355,8 @@ class EPMoE(FusedMoE):
...
@@ -345,6 +355,8 @@ class EPMoE(FusedMoE):
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
)
)
#dist.barrier()
(
(
dispatch_output
,
dispatch_output
,
dispatch_weights
,
dispatch_weights
,
...
@@ -360,13 +372,17 @@ class EPMoE(FusedMoE):
...
@@ -360,13 +372,17 @@ class EPMoE(FusedMoE):
#self.sync()
#self.sync()
#dispatch_recv_num_token = dispatch_recv_num_token[0].item()
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# #dispatch_recv_num_token = dispatch_recv_num_token.item()
# dispatch_output = dispatch_output[:dispatch_recv_num_token]
# dispatch_output = dispatch_output[:dispatch_recv_num_token]
# dispatch_weights = dispatch_weights[:dispatch_recv_num_token]
# dispatch_weights = dispatch_weights[:dispatch_recv_num_token]
# dispatch_indices = dispatch_indices[:dispatch_recv_num_token]
# dispatch_indices = dispatch_indices[:dispatch_recv_num_token]
# dispatch_recv_num_token = dispatch_recv_num_token.item()
# dispatch_output = torch.narrow(dispatch_output, dim=0, start=0, length=dispatch_recv_num_token)
# dispatch_weights = torch.narrow(dispatch_weights, dim=0, start=0, length=dispatch_recv_num_token)
# dispatch_indices = torch.narrow(dispatch_indices, dim=0, start=0, length=dispatch_recv_num_token)
# valid_mask = ((dispatch_indices <= 255) & (dispatch_indices >= 0)).all(dim=1)
# valid_mask = ((dispatch_indices <= 255) & (dispatch_indices >= 0)).all(dim=1)
# dispatch_output = dispatch_output[valid_mask]
# dispatch_output = dispatch_output[valid_mask]
...
@@ -418,26 +434,31 @@ class EPMoE(FusedMoE):
...
@@ -418,26 +434,31 @@ class EPMoE(FusedMoE):
final_hidden_states
=
final_hidden_states
+
shared_output
\
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
*
(
1.
/
self
.
routed_scaling_factor
)
return
final_hidden_states
if
envs
.
USE_FUSED_RMS_QUANT
:
return
final_hidden_states
,
new_resi
else
:
return
final_hidden_states
,
None
def
ep_moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
def
ep_moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
layer_name
:
str
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
return
self
.
forward_impl
(
hidden_states
,
router_logits
,
rms_weight
,
residual
)
def
ep_moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
def
ep_moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
layer_name
:
str
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
return
torch
.
empty_like
(
hidden_states
)
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
empty_like
(
hidden_states
),
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"ep_moe_forward"
,
op_name
=
"ep_moe_forward"
,
op_func
=
ep_moe_forward
,
op_func
=
ep_moe_forward
,
mutates_args
=
[
"hidden_states"
],
mutates_args
=
[
"hidden_states"
,
"router_logits"
,
"rms_weight"
,
"residual"
],
fake_impl
=
ep_moe_forward_fake
,
fake_impl
=
ep_moe_forward_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
...
...
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
View file @
8824ae6a
...
@@ -234,7 +234,7 @@ def moe_align_block_size(
...
@@ -234,7 +234,7 @@ def moe_align_block_size(
if
envs
.
VLLM_USE_LIGHT_OP
:
if
envs
.
VLLM_USE_LIGHT_OP
:
op
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
op
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
)
expert_ids
,
num_tokens_post_pad
,
None
)
else
:
else
:
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
)
expert_ids
,
num_tokens_post_pad
)
...
...
vllm/model_executor/layers/linear.py
View file @
8824ae6a
...
@@ -33,6 +33,12 @@ from vllm.platforms import current_platform
...
@@ -33,6 +33,12 @@ from vllm.platforms import current_platform
import
os
import
os
from
vllm.model_executor.utils
import
gemm_bank_conf
from
vllm.model_executor.utils
import
gemm_bank_conf
if
envs
.
USE_FUSED_RMS_QUANT
:
try
:
from
lmslim.quantize.quant_ops
import
lm_faster_rmsquant
except
Exception
as
e
:
print
(
f
"Error: Import fused rmsquant error:
{
e
}
"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
WEIGHT_LOADER_V2_SUPPORTED
=
[
WEIGHT_LOADER_V2_SUPPORTED
=
[
...
@@ -327,6 +333,7 @@ class ReplicatedLinear(LinearBase):
...
@@ -327,6 +333,7 @@ class ReplicatedLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
eps
:
Optional
[
float
]
=
1e-6
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
...
@@ -338,6 +345,7 @@ class ReplicatedLinear(LinearBase):
...
@@ -338,6 +345,7 @@ class ReplicatedLinear(LinearBase):
quant_config
,
quant_config
,
prefix
=
prefix
,
prefix
=
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
)
self
.
eps
=
eps
# All the linear layer supports quant method.
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
@@ -385,11 +393,49 @@ class ReplicatedLinear(LinearBase):
...
@@ -385,11 +393,49 @@ class ReplicatedLinear(LinearBase):
param
.
data
.
copy_
(
loaded_weight
)
param
.
data
.
copy_
(
loaded_weight
)
def
forward
(
def
forward
(
self
,
x
:
torch
.
Tensor
self
,
input_
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
quant_args
:
Optional
[
list
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
(
rms_weight
is
not
None
or
quant_args
is
not
None
):
if
quant_args
is
not
None
:
input_quant_args
=
quant_args
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
)
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
else
:
i_q
,
_scales
=
lm_faster_rmsquant
(
input
=
input_
,
rms_weight
=
rms_weight
,
epsilon
=
self
.
eps
,
quant_dtype
=
torch
.
int8
,
residual
=
residual
,
update_input
=
update_hd
)
new_residual
=
residual
input_quant_args
=
[
i_q
,
_scales
]
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
new_residual
,
output_bias
,
input_quant_args
else
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
if
not
self
.
return_bias
:
return
output
return
output
...
@@ -436,6 +482,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -436,6 +482,7 @@ class ColumnParallelLinear(LinearBase):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
eps
:
Optional
[
float
]
=
1e-6
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
...
@@ -459,7 +506,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -459,7 +506,7 @@ class ColumnParallelLinear(LinearBase):
quant_config
,
quant_config
,
prefix
,
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
)
self
.
eps
=
eps
self
.
gather_output
=
gather_output
self
.
gather_output
=
gather_output
if
output_sizes
is
None
:
if
output_sizes
is
None
:
...
@@ -543,10 +590,37 @@ class ColumnParallelLinear(LinearBase):
...
@@ -543,10 +590,37 @@ class ColumnParallelLinear(LinearBase):
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
)
def
forward
(
def
forward
(
self
,
input_
self
,
input_
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
input_quant_args
=
None
assert
rms_weight
is
not
None
i_q
,
_scales
=
lm_faster_rmsquant
(
input
=
input_
,
rms_weight
=
rms_weight
,
epsilon
=
self
.
eps
,
quant_dtype
=
torch
.
int8
,
residual
=
residual
,
update_input
=
update_hd
)
new_residual
=
residual
input_quant_args
=
[
i_q
,
_scales
]
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
if
self
.
gather_output
:
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
new_residual
,
output_bias
else
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
# Matrix multiply.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
...
@@ -593,6 +667,54 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -593,6 +667,54 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return_bias: If true, return bias together with outputs in forward pass.
return_bias: If true, return bias together with outputs in forward pass.
"""
"""
def
forward
(
self
,
input_
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
input_quant_args
=
None
assert
residual
is
not
None
and
rms_weight
is
not
None
i_q
,
_scales
=
lm_faster_rmsquant
(
input
=
input_
,
rms_weight
=
rms_weight
,
epsilon
=
self
.
eps
,
quant_dtype
=
torch
.
int8
,
residual
=
residual
,
update_input
=
update_hd
)
new_residual
=
residual
input_quant_args
=
[
i_q
,
_scales
]
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
new_residual
,
output_bias
else
:
# not USE_FUSED_RMS_QUANT
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
def
__init__
(
def
__init__
(
self
,
self
,
input_size
:
int
,
input_size
:
int
,
...
@@ -602,10 +724,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -602,10 +724,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
eps
:
Optional
[
float
]
=
1e-6
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
):
):
self
.
eps
=
eps
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
...
@@ -856,7 +980,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -856,7 +980,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
shard_offset
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
shard_size
=
shard_size
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
class
QKVParallelLinear
(
ColumnParallelLinear
):
"""Linear layers for the attention's QKV transformation.
"""Linear layers for the attention's QKV transformation.
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
8824ae6a
...
@@ -130,7 +130,7 @@ class AWQConfig(QuantizationConfig):
...
@@ -130,7 +130,7 @@ class AWQConfig(QuantizationConfig):
return
"awq"
return
"awq"
def
get_supported_act_dtypes
(
self
)
->
list
[
torch
.
dtype
]:
def
get_supported_act_dtypes
(
self
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
half
]
return
[
torch
.
half
,
torch
.
bfloat16
]
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
...
@@ -293,7 +293,7 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -293,7 +293,7 @@ class AWQLinearMethod(LinearMethodBase):
pad_group
=
2
pad_group
=
2
dim_n
=
layer
.
scales
.
data
.
shape
[
1
]
dim_n
=
layer
.
scales
.
data
.
shape
[
1
]
dim_k
=
layer
.
qweight
.
data
.
shape
[
0
]
dim_k
=
layer
.
qweight
.
data
.
shape
[
0
]
_qw
,
_sz
=
ops
.
convert_s4
(
layer
.
qweight
,
layer
.
qzeros
,
layer
.
scales
,
int
(
group_size
))
_qw
,
_sz
=
ops
.
convert_s4
(
layer
.
qweight
,
layer
.
qzeros
,
layer
.
scales
.
to
(
torch
.
float16
)
,
int
(
group_size
))
sz
=
ops
.
sz_permute
(
_sz
).
reshape
(
-
1
,
dim_n
)
sz
=
ops
.
sz_permute
(
_sz
).
reshape
(
-
1
,
dim_n
)
sz
=
sz
.
reshape
(
dim_n
,
-
1
)
sz
=
sz
.
reshape
(
dim_n
,
-
1
)
_qw
=
_qw
.
reshape
(
dim_n
,
-
1
)
_qw
=
_qw
.
reshape
(
dim_n
,
-
1
)
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
8824ae6a
...
@@ -10,7 +10,8 @@ import vllm.model_executor.layers.fused_moe # noqa
...
@@ -10,7 +10,8 @@ import vllm.model_executor.layers.fused_moe # noqa
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.layer
import
(
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
UnquantizedFusedMoEMethod
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
,
UnquantizedLinearMethod
,
set_weight_attrs
)
set_weight_attrs
)
...
@@ -140,6 +141,9 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -140,6 +141,9 @@ class AWQMarlinConfig(QuantizationConfig):
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
return
AWQMarlinLinearMethod
(
self
)
return
AWQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
if
is_layer_skipped_awq
(
prefix
,
getattr
(
self
,
"modules_to_not_convert"
,
[])):
return
UnquantizedFusedMoEMethod
(
layer
.
moe_config
)
from
vllm.model_executor.layers.quantization.moe_wna16
import
(
from
vllm.model_executor.layers.quantization.moe_wna16
import
(
MoeWNA16Config
)
MoeWNA16Config
)
if
not
check_moe_marlin_supports_layer
(
layer
,
self
.
group_size
):
if
not
check_moe_marlin_supports_layer
(
layer
,
self
.
group_size
):
...
@@ -436,7 +440,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -436,7 +440,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
# Why does this take the intermediate size for size_k?
# Why does this take the intermediate size for size_k?
marlin_w13_scales
=
marlin_moe_permute_scales
(
marlin_w13_scales
=
marlin_moe_permute_scales
(
s
=
layer
.
w13_scales
,
s
=
layer
.
w13_scales
.
to
(
torch
.
float16
)
,
size_k
=
layer
.
intermediate_size_per_partition
,
size_k
=
layer
.
intermediate_size_per_partition
,
size_n
=
layer
.
w13_scales
.
shape
[
2
],
size_n
=
layer
.
w13_scales
.
shape
[
2
],
group_size
=
self
.
quant_config
.
group_size
,
group_size
=
self
.
quant_config
.
group_size
,
...
@@ -445,7 +449,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -445,7 +449,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
#replace_parameter(layer, "w13_scales", marlin_w13_scales)
#replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales
=
marlin_moe_permute_scales
(
marlin_w2_scales
=
marlin_moe_permute_scales
(
s
=
layer
.
w2_scales
,
s
=
layer
.
w2_scales
.
to
(
torch
.
float16
)
,
size_k
=
layer
.
intermediate_size_per_partition
,
size_k
=
layer
.
intermediate_size_per_partition
,
size_n
=
layer
.
w2_scales
.
shape
[
2
],
size_n
=
layer
.
w2_scales
.
shape
[
2
],
group_size
=
self
.
quant_config
.
group_size
,
group_size
=
self
.
quant_config
.
group_size
,
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
8824ae6a
...
@@ -7,7 +7,8 @@ import torch
...
@@ -7,7 +7,8 @@ import torch
import
os
import
os
from
vllm.distributed
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.distributed
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.model_executor.layers.fused_moe.layer
import
(
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
UnquantizedFusedMoEMethod
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
...
@@ -18,7 +19,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -18,7 +19,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.quantization.awq
import
(
is_layer_skipped_awq
)
from
lmslim.layers.fused_moe.fuse_moe_int4
import
fused_experts_w4a16
from
lmslim.layers.fused_moe.fuse_moe_int4
import
fused_experts_w4a16
os
.
environ
[
'W4A16_MOE_CUDA'
]
=
os
.
environ
.
get
(
'W4A16_MOE_CUDA'
,
'0'
)
os
.
environ
[
'W4A16_MOE_CUDA'
]
=
os
.
environ
.
get
(
'W4A16_MOE_CUDA'
,
'0'
)
...
@@ -139,9 +141,9 @@ class MoeWNA16Config(QuantizationConfig):
...
@@ -139,9 +141,9 @@ class MoeWNA16Config(QuantizationConfig):
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped_quant
(
prefix
,
self
.
modules_to_not_convert
):
if
is_layer_skipped_quant
(
prefix
,
self
.
modules_to_not_convert
):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
elif
isinstance
(
layer
,
LinearBase
):
# Avoid circular import
# Avoid circular import
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq_marlin
import
(
from
vllm.model_executor.layers.quantization.awq_marlin
import
(
...
@@ -167,6 +169,9 @@ class MoeWNA16Config(QuantizationConfig):
...
@@ -167,6 +169,9 @@ class MoeWNA16Config(QuantizationConfig):
else
:
else
:
raise
ValueError
(
"moe_wna16 only support gptq and awq."
)
raise
ValueError
(
"moe_wna16 only support gptq and awq."
)
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
if
is_layer_skipped_awq
(
prefix
,
getattr
(
self
,
"modules_to_not_convert"
,
[])):
return
UnquantizedFusedMoEMethod
(
layer
.
moe_config
)
return
MoeWNA16Method
(
self
)
return
MoeWNA16Method
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
8824ae6a
...
@@ -21,6 +21,8 @@ from vllm.utils import W8a8GetCacheJSON
...
@@ -21,6 +21,8 @@ from vllm.utils import W8a8GetCacheJSON
import
os
import
os
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
try
:
try
:
from
lmslim.layers.fused_moe.fuse_moe_w4a8
import
fused_experts_impl_w4a8_ep
from
lmslim.layers.fused_moe.fuse_moe_w4a8
import
fused_experts_impl_w4a8_ep
except
Exception
:
except
Exception
:
...
@@ -156,7 +158,12 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
...
@@ -156,7 +158,12 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
):
):
if
envs
.
USE_FUSED_RMS_QUANT
and
input_quant_args
is
not
None
:
assert
len
(
input_quant_args
)
==
2
x_q
,
x_scale
=
input_quant_args
else
:
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
if
self
.
w8a8_strategy
==
1
:
if
self
.
w8a8_strategy
==
1
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment