Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f44e9f9e
Commit
f44e9f9e
authored
Apr 18, 2025
by
zhuwenwen
Browse files
Merge branch 'v0.7.2-dev' into v0.7.2-fusion
parents
525d9d7e
8fc15e04
Changes
35
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
518 additions
and
99 deletions
+518
-99
README.md
README.md
+4
-1
benchmarks/benchmark_serving.py
benchmarks/benchmark_serving.py
+2
-0
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+5
-4
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+51
-24
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+163
-7
examples/mla/test_triton_decode_attention.py
examples/mla/test_triton_decode_attention.py
+1
-2
examples/mla/triton_decode_attention.py
examples/mla/triton_decode_attention.py
+1
-2
setup.py
setup.py
+2
-2
vllm/attention/backends/mla/utils.py
vllm/attention/backends/mla/utils.py
+8
-3
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+4
-5
vllm/benchmarks/benchmark_throughput.py
vllm/benchmarks/benchmark_throughput.py
+4
-1
vllm/config.py
vllm/config.py
+1
-1
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+16
-10
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+0
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+181
-4
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+7
-5
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+5
-0
vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
.../openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
+52
-26
vllm/envs.py
vllm/envs.py
+10
-0
No files found.
README.md
View file @
f44e9f9e
...
@@ -12,7 +12,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
...
@@ -12,7 +12,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| :------: | :------: | :------: | :------: |:------: |
| :------: | :------: | :------: | :------: |:------: |
| LlamaForCausalLM | Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,DeepSeek-R1-Distill-Llama | Yes | Yes | Yes |
| LlamaForCausalLM | Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,DeepSeek-R1-Distill-Llama | Yes | Yes | Yes |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen | Yes | Yes | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct | Yes | Yes | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes |
| DeepseekForCausalLM | Deepseek | Yes | No | - |
| DeepseekForCausalLM | Deepseek | Yes | No | - |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | No | - |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | No | - |
...
@@ -31,6 +31,9 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
...
@@ -31,6 +31,9 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| Qwen2VLForConditionalGeneration | Qwen2-VL | Yes | No | Yes |
| Qwen2VLForConditionalGeneration | Qwen2-VL | Yes | No | Yes |
| MiniCPMV | MiniCPM-V | Yes | No | - |
| MiniCPMV | MiniCPM-V | Yes | No | - |
| Phi3VForCausalLM | Phi-3.5-vision | Yes | No | - |
| Phi3VForCausalLM | Phi-3.5-vision | Yes | No | - |
| BertModel | bge-large-zh-v1.5 | Yes | No | - |
| XLMRobertaModel | bge-m3 | Yes | No | - |
| XLMRobertaForSequenceClassification | bge-reranker-v2-m3 | Yes | No | - |
## 安装
## 安装
...
...
benchmarks/benchmark_serving.py
View file @
f44e9f9e
...
@@ -570,6 +570,8 @@ async def benchmark(
...
@@ -570,6 +570,8 @@ async def benchmark(
else
:
else
:
print
(
"Initial test run completed. Starting main benchmark run..."
)
print
(
"Initial test run completed. Starting main benchmark run..."
)
time
.
sleep
(
0.1
)
# ZERO_OVERHEAD : sleep and wait the last step in warmup
if
profile
:
if
profile
:
print
(
"Starting profiler..."
)
print
(
"Starting profiler..."
)
profile_input
=
RequestFuncInput
(
model
=
model_id
,
profile_input
=
RequestFuncInput
(
model
=
model_id
,
...
...
benchmarks/benchmark_throughput.py
View file @
f44e9f9e
...
@@ -8,7 +8,7 @@ import time
...
@@ -8,7 +8,7 @@ import time
from
pathlib
import
Path
from
pathlib
import
Path
from
functools
import
cache
from
functools
import
cache
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
os
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
uvloop
import
uvloop
...
@@ -180,7 +180,7 @@ def run_vllm(
...
@@ -180,7 +180,7 @@ def run_vllm(
sampling_params
:
List
[
SamplingParams
]
=
[]
sampling_params
:
List
[
SamplingParams
]
=
[]
for
request
in
requests
:
for
request
in
requests
:
prompts
.
append
(
prompts
.
append
(
TextPrompt
(
prompt
=
request
.
prompt
,
TextPrompt
(
prompt
=
"helloword"
,
multi_modal_data
=
request
.
multi_modal_data
))
multi_modal_data
=
request
.
multi_modal_data
))
sampling_params
.
append
(
sampling_params
.
append
(
SamplingParams
(
SamplingParams
(
...
@@ -206,15 +206,16 @@ def run_vllm(
...
@@ -206,15 +206,16 @@ def run_vllm(
dummy_prompts
:
List
[
PromptType
]
=
[{
dummy_prompts
:
List
[
PromptType
]
=
[{
"prompt_token_ids"
:
batch
"prompt_token_ids"
:
batch
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
print
(
f
'
{
os
.
environ
.
get
(
"VLLM_ZERO_OVERHEAD"
)
==
"1"
}
'
)
print
(
"Warming up..."
)
print
(
"Warming up..."
)
for
_
in
tqdm
(
range
(
num_iters_warmup
),
desc
=
"Warmup iterations"
):
for
_
in
tqdm
(
range
(
num_iters_warmup
),
desc
=
"Warmup iterations"
):
llm
.
generate
(
dummy_prompts
,
llm
.
generate
(
dummy_prompts
,
sampling_params
=
warmup_sampling_params
,
sampling_params
=
warmup_sampling_params
,
use_tqdm
=
False
)
use_tqdm
=
False
)
use_beam_search
=
False
use_beam_search
=
False
print
(
"testing"
)
if
not
use_beam_search
:
if
not
use_beam_search
:
if
args
.
profile
:
if
args
.
profile
:
profile_dir
=
args
.
profile_result_dir
profile_dir
=
args
.
profile_result_dir
...
...
csrc/custom_all_reduce.cu
View file @
f44e9f9e
...
@@ -14,14 +14,14 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
...
@@ -14,14 +14,14 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
full_nvlink
)
{
bool
full_nvlink
)
{
int
world_size
=
fake_ipc_ptrs
.
size
();
int
world_size
=
fake_ipc_ptrs
.
size
();
if
(
world_size
>
8
)
if
(
world_size
>
16
)
throw
std
::
invalid_argument
(
"world size > 8 is not supported"
);
throw
std
::
invalid_argument
(
"world size > 8 is not supported"
);
if
(
world_size
%
2
!=
0
)
if
(
world_size
%
2
!=
0
)
throw
std
::
invalid_argument
(
"Odd num gpus is not supported for now"
);
throw
std
::
invalid_argument
(
"Odd num gpus is not supported for now"
);
if
(
rank
<
0
||
rank
>=
world_size
)
if
(
rank
<
0
||
rank
>=
world_size
)
throw
std
::
invalid_argument
(
"invalid rank passed in"
);
throw
std
::
invalid_argument
(
"invalid rank passed in"
);
vllm
::
Signal
*
ipc_ptrs
[
8
];
vllm
::
Signal
*
ipc_ptrs
[
16
];
for
(
int
i
=
0
;
i
<
world_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
world_size
;
i
++
)
{
ipc_ptrs
[
i
]
=
reinterpret_cast
<
vllm
::
Signal
*>
(
fake_ipc_ptrs
[
i
]);
ipc_ptrs
[
i
]
=
reinterpret_cast
<
vllm
::
Signal
*>
(
fake_ipc_ptrs
[
i
]);
}
}
...
@@ -78,29 +78,56 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
...
@@ -78,29 +78,56 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
}
else
{
}
else
{
reg_buffer
=
inp
.
data_ptr
();
reg_buffer
=
inp
.
data_ptr
();
}
}
switch
(
out
.
scalar_type
())
{
if
(
fa
->
full_nvlink_
)
{
case
at
::
ScalarType
::
Float
:
{
switch
(
out
.
scalar_type
())
{
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
reg_buffer
),
case
at
::
ScalarType
::
Float
:
{
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
reg_buffer
),
out
.
numel
());
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
break
;
out
.
numel
());
break
;
}
case
at
::
ScalarType
::
Half
:
{
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
reg_buffer
),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
// #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
reg_buffer
),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
// #endif
default:
throw
std
::
runtime_error
(
"custom allreduce only supports float32, float16 and bfloat16"
);
}
}
case
at
::
ScalarType
::
Half
:
{
}
else
{
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
reg_buffer
),
switch
(
out
.
scalar_type
())
{
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
case
at
::
ScalarType
::
Float
:
{
break
;
fa
->
allreduce_pcie
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
reg_buffer
),
}
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
out
.
numel
());
case
at
::
ScalarType
::
BFloat16
:
{
break
;
fa
->
allreduce
<
nv_bfloat16
>
(
}
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
reg_buffer
),
case
at
::
ScalarType
::
Half
:
{
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
fa
->
allreduce_pcie
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
reg_buffer
),
break
;
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
// #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce_pcie
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
reg_buffer
),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
// #endif
default:
throw
std
::
runtime_error
(
"custom allreduce only supports float32, float16 and bfloat16"
);
}
}
#endif
default:
throw
std
::
runtime_error
(
"custom allreduce only supports float32, float16 and bfloat16"
);
}
}
}
}
...
@@ -113,7 +140,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); }
...
@@ -113,7 +140,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); }
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
fptr_t
>&
fake_ipc_ptrs
)
{
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
fptr_t
>&
fake_ipc_ptrs
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
TORCH_CHECK
(
fake_ipc_ptrs
.
size
()
==
fa
->
world_size_
);
TORCH_CHECK
(
fake_ipc_ptrs
.
size
()
==
fa
->
world_size_
);
void
*
ipc_ptrs
[
8
];
void
*
ipc_ptrs
[
16
];
for
(
int
i
=
0
;
i
<
fake_ipc_ptrs
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
fake_ipc_ptrs
.
size
();
i
++
)
{
ipc_ptrs
[
i
]
=
reinterpret_cast
<
void
*>
(
fake_ipc_ptrs
[
i
]);
ipc_ptrs
[
i
]
=
reinterpret_cast
<
void
*>
(
fake_ipc_ptrs
[
i
]);
}
}
...
...
csrc/custom_all_reduce.cuh
View file @
f44e9f9e
...
@@ -52,17 +52,17 @@ using FlagType = uint32_t;
...
@@ -52,17 +52,17 @@ using FlagType = uint32_t;
// waiting for counter. We use alternating counter array to avoid this
// waiting for counter. We use alternating counter array to avoid this
// possibility.
// possibility.
struct
Signal
{
struct
Signal
{
alignas
(
128
)
FlagType
start
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
start
[
kMaxBlocks
][
16
];
alignas
(
128
)
FlagType
end
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
end
[
kMaxBlocks
][
16
];
alignas
(
128
)
FlagType
_flag
[
kMaxBlocks
];
// incremental flags for each rank
alignas
(
128
)
FlagType
_flag
[
kMaxBlocks
];
// incremental flags for each rank
};
};
struct
__align__
(
16
)
RankData
{
struct
__align__
(
16
)
RankData
{
const
void
*
ptrs
[
8
];
const
void
*
ptrs
[
16
];
};
};
struct
__align__
(
16
)
RankSignals
{
struct
__align__
(
16
)
RankSignals
{
Signal
*
signals
[
8
];
Signal
*
signals
[
16
];
};
};
// like std::array, but aligned
// like std::array, but aligned
...
@@ -104,7 +104,7 @@ DINLINE half& assign_add(half& a, half b) {
...
@@ -104,7 +104,7 @@ DINLINE half& assign_add(half& a, half b) {
}
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
//
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE
float
upcast_s
(
nv_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
DINLINE
float
upcast_s
(
nv_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
template
<
>
template
<
>
DINLINE
nv_bfloat16
downcast_s
(
float
val
)
{
DINLINE
nv_bfloat16
downcast_s
(
float
val
)
{
...
@@ -114,7 +114,7 @@ DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
...
@@ -114,7 +114,7 @@ DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a
=
__hadd
(
a
,
b
);
a
=
__hadd
(
a
,
b
);
return
a
;
return
a
;
}
}
#endif
//
#endif
template
<
typename
T
,
int
N
>
template
<
typename
T
,
int
N
>
DINLINE
array_t
<
T
,
N
>&
packed_assign_add
(
array_t
<
T
,
N
>&
a
,
array_t
<
T
,
N
>
b
)
{
DINLINE
array_t
<
T
,
N
>&
packed_assign_add
(
array_t
<
T
,
N
>&
a
,
array_t
<
T
,
N
>
b
)
{
...
@@ -373,6 +373,84 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -373,6 +373,84 @@ __global__ void __launch_bounds__(512, 1)
}
}
}
}
/**
 * One-stage all-reduce kernel used on the non-full-NVLink ("pcie") path.
 *
 * Every rank reads the packed input of all `ngpus` ranks and writes the
 * element-wise reduction to its own `result` buffer.
 *
 * @param _dp           Per-rank input pointers; copied into a local before use.
 * @param sg            Signal buffers of all ranks, passed to start/end sync.
 * @param self_sg       This rank's own signal buffer.
 * @param result        Output buffer, indexed here as an array of packed `P`.
 * @param rank          Index of the calling rank.
 * @param size          Number of packed `P` elements to reduce.
 * @param curr_hdp_reg  Array of `world_size` pointers; 0x1 is stored through
 *                      each one before the start barrier.
 *                      NOTE(review): presumably the per-device HDP flush
 *                      control registers obtained on the host via
 *                      hipDeviceAttributeHdpMemFlushCntl - confirm.
 * @param world_size    Number of entries of `curr_hdp_reg` to write.
 */
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage_pcie(RankData* _dp, RankSignals sg,
                                    Signal* self_sg, T* __restrict__ result,
                                    int rank, int size,
                                    uint32_t** curr_hdp_reg, int world_size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto rank_data = *_dp;

  // Only the thread with threadIdx.x == 1 of each block performs the stores.
  if (threadIdx.x == 1) {
    for (int peer = 0; peer < world_size; ++peer) {
      __atomic_store_n(curr_hdp_reg[peer], 0x1, __ATOMIC_RELAXED);
    }
  }

  start_sync<ngpus>(sg, self_sg, rank);

  // do the actual reduction: each thread starts at its global index and
  // advances by the total number of launched threads.
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const int step = gridDim.x * blockDim.x;
  P* packed_out = (P*)result;
  const P** packed_in = (const P**)&rank_data.ptrs[0];
  for (int idx = first; idx < size; idx += step) {
    packed_out[idx] = packed_reduce<P, ngpus, A>(packed_in, idx);
  }

  end_sync<ngpus, true>(sg, self_sg, rank);
}
/**
 * Two-stage all-reduce kernel used on the non-full-NVLink ("pcie") path.
 *
 * Stage 1 (reduce-scatter): each rank reduces its own slice
 * [rank * part, end) of the packed input across all `ngpus` ranks and writes
 * the partial result into its temporary buffer (`tmps[0]`).
 * Stage 2 (all-gather): each rank copies every rank's slice from that rank's
 * temporary buffer into its own `result`.
 *
 * @param _dp           Per-rank input pointers.
 * @param sg            Signal buffers of all ranks; also the source of the
 *                      temporary buffers via get_tmp_buf.
 * @param self_sg       This rank's own signal buffer.
 * @param result        Output buffer, indexed here as an array of packed `P`.
 * @param rank          Index of the calling rank.
 * @param size          Number of packed `P` elements to reduce.
 * @param curr_hdp_reg  Array of `world_size` pointers; 0x1 is stored through
 *                      each one before the start barrier.
 *                      NOTE(review): presumably the per-device HDP flush
 *                      control registers - confirm against the host setup.
 * @param world_size    Number of entries of `curr_hdp_reg` to write.
 */
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage_pcie(RankData* _dp, RankSignals sg,
                                    Signal* self_sg, T* __restrict__ result,
                                    int rank, int size,
                                    uint32_t** curr_hdp_reg, int world_size) {
  // Global thread index and total number of launched threads.
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // Slice owned by this rank; the last rank also takes the remainder
  // (size % ngpus), hence `largest_part`.
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
  // Only the thread with threadIdx.x == 1 of each block performs the stores.
  if (threadIdx.x == 1) {
    for (int i = 0; i < world_size; i++) {
      __atomic_store_n(curr_hdp_reg[i], 0x1, __ATOMIC_RELAXED);
    }
  }
  // Rotate the rank order so that index 0 always refers to this rank.
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  // tmps[0] is this rank's own temporary buffer (target == rank when i == 0).
  auto tmp_out = tmps[0];
  start_sync<ngpus>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  end_sync<ngpus>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from
  // all ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      // Indices >= part exist only in the last rank's (larger) slice.
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}
using
IPC_KEY
=
std
::
array
<
uint8_t
,
sizeof
(
cudaIpcMemHandle_t
)
>
;
using
IPC_KEY
=
std
::
array
<
uint8_t
,
sizeof
(
cudaIpcMemHandle_t
)
>
;
static_assert
(
sizeof
(
IPC_KEY
)
==
sizeof
(
cudaIpcMemHandle_t
));
static_assert
(
sizeof
(
IPC_KEY
)
==
sizeof
(
cudaIpcMemHandle_t
));
static_assert
(
alignof
(
IPC_KEY
)
==
alignof
(
cudaIpcMemHandle_t
));
static_assert
(
alignof
(
IPC_KEY
)
==
alignof
(
cudaIpcMemHandle_t
));
...
@@ -409,6 +487,7 @@ class CustomAllreduce {
...
@@ -409,6 +487,7 @@ class CustomAllreduce {
// a map from IPC handles to opened IPC pointers
// a map from IPC handles to opened IPC pointers
std
::
map
<
IPC_KEY
,
char
*>
ipc_handles_
;
std
::
map
<
IPC_KEY
,
char
*>
ipc_handles_
;
uint32_t
**
dev_curr_hdp_reg
;
/**
/**
* Signals are an array of ipc-enabled buffers from all ranks.
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
* For each of the buffer, the layout is as follows:
...
@@ -431,6 +510,12 @@ class CustomAllreduce {
...
@@ -431,6 +510,12 @@ class CustomAllreduce {
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
sg_
.
signals
[
i
]
=
signals
[
i
];
sg_
.
signals
[
i
]
=
signals
[
i
];
}
}
if
(
!
full_nvlink
)
{
cudaMalloc
((
void
**
)
&
dev_curr_hdp_reg
,
world_size_
*
sizeof
(
uint32_t
*
));
for
(
int
i
=
0
;
i
<
world_size_
;
++
i
)
{
hipDeviceGetAttribute
((
int
*
)
&
dev_curr_hdp_reg
[
i
],
hipDeviceAttributeHdpMemFlushCntl
,
i
);
}
}
}
}
char
*
open_ipc_handle
(
const
void
*
ipc_handle
)
{
char
*
open_ipc_handle
(
const
void
*
ipc_handle
)
{
...
@@ -522,6 +607,75 @@ class CustomAllreduce {
...
@@ -522,6 +607,75 @@ class CustomAllreduce {
graph_unreg_buffers_
.
clear
();
graph_unreg_buffers_
.
clear
();
}
}
/**
 * Launches the PCIe variant of the custom all-reduce on `stream`.
 *
 * @param stream       Stream the kernel is launched on. If the stream is
 *                     being captured, `input` is recorded in
 *                     graph_unreg_buffers_ and the rank-data slot is taken
 *                     from d_rank_data_base_; otherwise `input` must already
 *                     be present in buffers_.
 * @param input        Input buffer of this rank.
 * @param output       Output buffer receiving the reduced values.
 * @param size         Number of `T` elements; must be a multiple of the
 *                     packed width `packed_t<T>::P::size`.
 * @param threads      Threads per block.
 * @param block_limit  Upper bound on the grid size; must not exceed
 *                     kMaxBlocks.
 * @throws std::runtime_error if `size` is not a multiple of the packed
 *         width, if `block_limit` exceeds kMaxBlocks, if `input` is not
 *         registered (outside of stream capture), or if world_size_ is not
 *         one of 2, 4, 6, 8, 16.
 */
template <typename T>
void allreduce_pcie(cudaStream_t stream, T* input, T* output, int size,
                    int threads = 512, int block_limit = defaultBlockLimit) {
  auto pack_width = packed_t<T>::P::size;
  if (size % pack_width != 0) {
    throw std::runtime_error(
        "custom allreduce currently requires input length to be multiple "
        "of " +
        std::to_string(pack_width));
  }
  if (block_limit > kMaxBlocks) {
    throw std::runtime_error("max supported block limit is " +
                             std::to_string(kMaxBlocks) + ". Got " +
                             std::to_string(block_limit));
  }

  // Resolve the rank-data entry that describes `input` on every rank.
  RankData* rank_data;
  cudaStreamCaptureStatus capture_status;
  CUDACHECK(cudaStreamIsCapturing(stream, &capture_status));
  if (capture_status != cudaStreamCaptureStatusActive) {
    auto entry = buffers_.find(input);
    if (entry == buffers_.end()) {
      throw std::runtime_error(
          "buffer address " +
          std::to_string(reinterpret_cast<uint64_t>(input)) +
          " is not registered!");
    }
    rank_data = entry->second;
  } else {
    rank_data = d_rank_data_base_ + graph_unreg_buffers_.size();
    graph_unreg_buffers_.push_back(input);
  }

  // From here on `size` counts packed elements rather than `T` elements.
  size /= pack_width;
  auto payload_bytes = size * sizeof(typename packed_t<T>::P);
  int grid = std::min(block_limit, (size + threads - 1) / threads);

  // Two ranks always use the 1-stage kernel; otherwise it is used only for
  // payloads below the per-world-size thresholds.
  const bool single_stage =
      world_size_ == 2 ||
      (world_size_ <= 4 && payload_bytes < 128 * 8192) ||
      (world_size_ <= 8 && payload_bytes < 8 * 8192);

#define KL(ngpus, name)                                                      \
  name<T, ngpus><<<grid, threads, 0, stream>>>(rank_data, sg_, self_sg_,     \
                                               output, rank_, size,          \
                                               dev_curr_hdp_reg, world_size_)
#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (single_stage) {                               \
      KL(ngpus, cross_device_reduce_1stage_pcie);     \
    } else {                                          \
      KL(ngpus, cross_device_reduce_2stage_pcie);     \
    }                                                 \
    break;                                            \
  }

  switch (world_size_) {
    REDUCE_CASE(2)
    REDUCE_CASE(4)
    REDUCE_CASE(6)
    REDUCE_CASE(8)
    REDUCE_CASE(16)
    default:
      throw std::runtime_error(
          "custom allreduce only supports num gpus in (2,4,6,8,16). Actual "
          "num "
          "gpus = " +
          std::to_string(world_size_));
  }
#undef REDUCE_CASE
#undef KL
}
/**
/**
* Performs allreduce, assuming input has already been registered.
* Performs allreduce, assuming input has already been registered.
*
*
...
@@ -587,9 +741,10 @@ class CustomAllreduce {
...
@@ -587,9 +741,10 @@ class CustomAllreduce {
REDUCE_CASE
(
4
)
REDUCE_CASE
(
4
)
REDUCE_CASE
(
6
)
REDUCE_CASE
(
6
)
REDUCE_CASE
(
8
)
REDUCE_CASE
(
8
)
REDUCE_CASE
(
16
)
default:
default:
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"custom allreduce only supports num gpus in (2,4,6,8). Actual "
"custom allreduce only supports num gpus in (2,4,6,8
,16
). Actual "
"num "
"num "
"gpus = "
+
"gpus = "
+
std
::
to_string
(
world_size_
));
std
::
to_string
(
world_size_
));
...
@@ -602,6 +757,7 @@ class CustomAllreduce {
...
@@ -602,6 +757,7 @@ class CustomAllreduce {
for
(
auto
[
_
,
ptr
]
:
ipc_handles_
)
{
for
(
auto
[
_
,
ptr
]
:
ipc_handles_
)
{
CUDACHECK
(
cudaIpcCloseMemHandle
(
ptr
));
CUDACHECK
(
cudaIpcCloseMemHandle
(
ptr
));
}
}
cudaFree
(
dev_curr_hdp_reg
);
}
}
};
};
...
...
examples/mla/test_triton_decode_attention.py
View file @
f44e9f9e
...
@@ -212,5 +212,4 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
...
@@ -212,5 +212,4 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
# 保存最佳配置
# 保存最佳配置
with
open
(
file_name
,
'w'
)
as
file
:
with
open
(
file_name
,
'w'
)
as
file
:
json
.
dump
(
config_info
,
file
,
indent
=
1
)
json
.
dump
(
config_info
,
file
,
indent
=
1
)
#**************save config**************#
#**************save config**************#
\ No newline at end of file
examples/mla/triton_decode_attention.py
View file @
f44e9f9e
...
@@ -1237,5 +1237,4 @@ def decode_attentionv1_fwd(
...
@@ -1237,5 +1237,4 @@ def decode_attentionv1_fwd(
page_size
,
page_size
,
logit_cap
,
logit_cap
,
)
)
return
v1_tc_stage1_best_config
,
v1_tc_stage2_best_config
return
v1_tc_stage1_best_config
,
v1_tc_stage2_best_config
\ No newline at end of file
setup.py
View file @
f44e9f9e
...
@@ -488,11 +488,11 @@ def get_version_add(sha: Optional[str] = None) -> str:
...
@@ -488,11 +488,11 @@ def get_version_add(sha: Optional[str] = None) -> str:
if
sha
is
None
:
if
sha
is
None
:
sha
=
get_sha
(
vllm_root
)
sha
=
get_sha
(
vllm_root
)
if
(
major
,
minor
)
==
(
'2'
,
'4'
):
if
(
major
,
minor
)
==
(
'2'
,
'4'
):
version
=
'das.opt1.
alph
a.'
+
sha
[:
7
]
version
=
'das.opt1.
bet
a.'
+
sha
[:
7
]
# version = 'das.opt1.' + sha[:7]
# version = 'das.opt1.' + sha[:7]
else
:
else
:
if
(
major
,
minor
)
==
(
'2'
,
'4'
):
if
(
major
,
minor
)
==
(
'2'
,
'4'
):
version
=
'das.opt1.
alph
a'
version
=
'das.opt1.
bet
a'
# version = 'das.opt1'
# version = 'das.opt1'
...
...
vllm/attention/backends/mla/utils.py
View file @
f44e9f9e
...
@@ -533,13 +533,16 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -533,13 +533,16 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# For MLA the v head dim is smaller than qk head dim so we pad out
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
# v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
# value=0)
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
(
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]
-
32
)],
value
=
0
)
value
=
0
)
v_tmp
=
v_padded
[...,
:
-
32
].
reshape
(
v
.
shape
[
0
],
v
.
shape
[
1
],
v
.
shape
[
2
])
attn_output
=
flash_attn_varlen_func
(
attn_output
=
flash_attn_varlen_func
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
v
=
v_
padded
,
v
=
v_
tmp
if
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
else
v
,
cu_seqlens_q
=
seq_start_loc
,
cu_seqlens_q
=
seq_start_loc
,
cu_seqlens_k
=
seq_start_loc
,
cu_seqlens_k
=
seq_start_loc
,
max_seqlen_q
=
max_prefill_seq_len
,
max_seqlen_q
=
max_prefill_seq_len
,
...
@@ -547,8 +550,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -547,8 +550,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
)
)
# output = output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
attn_output
=
attn_output
\
attn_output
=
attn_output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
attn_output
)[
0
]
return
self
.
o_proj
(
attn_output
)[
0
]
vllm/attention/backends/rocm_flash_attn.py
View file @
f44e9f9e
...
@@ -790,6 +790,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -790,6 +790,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
prefill_meta
.
block_tables
,
prefill_meta
.
block_tables
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
context_lens_tensor
,
prefill_meta
.
max_query_len
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
self
.
sliding_window
[
0
],
...
...
vllm/attention/backends/utils.py
View file @
f44e9f9e
...
@@ -14,8 +14,6 @@ from vllm.attention.backends.abstract import AttentionType
...
@@ -14,8 +14,6 @@ from vllm.attention.backends.abstract import AttentionType
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerBase
from
vllm.worker.model_runner_base
import
ModelRunnerBase
...
@@ -235,8 +233,10 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -235,8 +233,10 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
from_numpy
(
input_block_tables
).
to
(
# block_tables = torch.from_numpy(input_block_tables).to(
device
,
non_blocking
=
True
)
# device, non_blocking=True)
block_tables
=
async_tensor_h2d
(
input_block_tables
.
tolist
(),
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
else
:
else
:
block_tables
=
make_tensor_with_pad
(
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
self
.
block_tables
,
...
@@ -245,7 +245,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -245,7 +245,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device
=
device
,
device
=
device
,
)
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
assert
device
is
not
None
assert
device
is
not
None
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
device
,
self
.
runner
.
pin_memory
)
...
...
vllm/benchmarks/benchmark_throughput.py
View file @
f44e9f9e
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
argparse
import
argparse
import
dataclasses
import
dataclasses
import
json
import
json
import
os
import
random
import
random
import
time
import
time
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -214,7 +215,9 @@ def run_vllm(
...
@@ -214,7 +215,9 @@ def run_vllm(
use_tqdm
=
False
)
use_tqdm
=
False
)
use_beam_search
=
False
use_beam_search
=
False
if
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
:
print
(
"sleep 1"
)
time
.
sleep
(
1
)
# ZERO_OVERHEAD : sleep and wait the last step in warmup
if
not
use_beam_search
:
if
not
use_beam_search
:
if
args
.
profile
:
if
args
.
profile
:
profile_dir
=
args
.
profile_result_dir
profile_dir
=
args
.
profile_result_dir
...
...
vllm/config.py
View file @
f44e9f9e
...
@@ -423,7 +423,7 @@ class ModelConfig:
...
@@ -423,7 +423,7 @@ class ModelConfig:
self
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
self
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
)
->
Optional
[
"MultiModalConfig"
]:
)
->
Optional
[
"MultiModalConfig"
]:
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
if
ModelRegistry
.
is_multimodal_model
(
architectures
):
if
ModelRegistry
.
is_multimodal_model
(
architectures
)
and
hasattr
(
self
.
hf_config
,
"vision_config"
)
:
return
MultiModalConfig
(
limit_per_prompt
=
limit_mm_per_prompt
or
{})
return
MultiModalConfig
(
limit_per_prompt
=
limit_mm_per_prompt
or
{})
if
limit_mm_per_prompt
:
if
limit_mm_per_prompt
:
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
f44e9f9e
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
ctypes
import
ctypes
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
import
os
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
...
@@ -18,6 +18,7 @@ from vllm.logger import init_logger
...
@@ -18,6 +18,7 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.utils
import
cuda_device_count_stateless
from
vllm
import
envs
try
:
try
:
ops
.
meta_size
()
ops
.
meta_size
()
custom_ar
=
True
custom_ar
=
True
...
@@ -50,13 +51,13 @@ def is_weak_contiguous(inp: torch.Tensor):
...
@@ -50,13 +51,13 @@ def is_weak_contiguous(inp: torch.Tensor):
class
CustomAllreduce
:
class
CustomAllreduce
:
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
]
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
,
16
]
# max_size: max supported allreduce size
# max_size: max supported allreduce size
def
__init__
(
self
,
def
__init__
(
self
,
group
:
ProcessGroup
,
group
:
ProcessGroup
,
device
:
Union
[
int
,
str
,
torch
.
device
],
device
:
Union
[
int
,
str
,
torch
.
device
],
max_size
=
8192
*
1024
*
2
)
->
None
:
max_size
=
8192
*
1024
)
->
None
:
"""
"""
Args:
Args:
group: the process group to work on. If None, it will use the
group: the process group to work on. If None, it will use the
...
@@ -137,11 +138,18 @@ class CustomAllreduce:
...
@@ -137,11 +138,18 @@ class CustomAllreduce:
full_nvlink
=
current_platform
.
is_fully_connected_nvlink_or_xgmi
(
full_nvlink
=
current_platform
.
is_fully_connected_nvlink_or_xgmi
(
physical_device_ids
)
physical_device_ids
)
if
not
full_nvlink
:
if
not
full_nvlink
:
max_size
=
32
*
8192
*
2
if
not
envs
.
VLLM_PCIE_USE_CUSTOM_ALLREDUCE
:
logger
.
warning
(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly."
)
return
logger
.
warning
(
logger
.
warning
(
"
Custom allreduce is disabled because it's not supported on
"
"
We are using PCIe's custom allreduce.
"
"
more than two PCIe-only GPUs. To silence this warning,
"
"
If the performance is poor, we can add
"
"
specify
disable
_
custom
_
all
_
reduce
=True explicitly
."
)
"
--
disable
-
custom
-
all
-
reduce
in the instruction
."
)
return
# test P2P capability, this checks software/cudaruntime support
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# this is expensive to compute at the first time
# then we cache the result
# then we cache the result
...
@@ -259,9 +267,7 @@ class CustomAllreduce:
...
@@ -259,9 +267,7 @@ class CustomAllreduce:
return
False
return
False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
# little performance improvement over NCCL.
if
self
.
world_size
==
2
or
self
.
full_nvlink
:
return
inp_size
<
self
.
max_size
return
inp_size
<
self
.
max_size
return
False
def
all_reduce
(
self
,
def
all_reduce
(
self
,
inp
:
torch
.
Tensor
,
inp
:
torch
.
Tensor
,
...
...
vllm/engine/async_llm_engine.py
View file @
f44e9f9e
...
@@ -726,7 +726,6 @@ class AsyncLLMEngine(EngineClient):
...
@@ -726,7 +726,6 @@ class AsyncLLMEngine(EngineClient):
"""Kick the engine to process the waiting requests.
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
Returns True if there are in-progress requests."""
new_requests
,
aborted_requests
=
(
new_requests
,
aborted_requests
=
(
self
.
_request_tracker
.
get_new_and_aborted_requests
())
self
.
_request_tracker
.
get_new_and_aborted_requests
())
...
@@ -746,7 +745,6 @@ class AsyncLLMEngine(EngineClient):
...
@@ -746,7 +745,6 @@ class AsyncLLMEngine(EngineClient):
await
self
.
_engine_abort
(
aborted_requests
)
await
self
.
_engine_abort
(
aborted_requests
)
request_outputs
=
await
self
.
engine
.
step_async
(
virtual_engine
)
request_outputs
=
await
self
.
engine
.
step_async
(
virtual_engine
)
# Put the outputs into the corresponding streams.
# Put the outputs into the corresponding streams.
# If used as a callback, then already invoked inside
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
# LLMEngine's _process_model_outputs
...
...
vllm/engine/llm_engine.py
View file @
f44e9f9e
...
@@ -3,11 +3,14 @@
...
@@ -3,11 +3,14 @@
import
os
import
os
import
copy
import
copy
import
time
import
time
import
threading
import
queue
from
collections
import
Counter
as
collectionsCounter
from
collections
import
Counter
as
collectionsCounter
from
collections
import
deque
from
collections
import
deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
functools
import
partial
import
traceback
from
typing
import
(
TYPE_CHECKING
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
from
typing
import
(
TYPE_CHECKING
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
...
@@ -61,6 +64,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
...
@@ -61,6 +64,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message
)
usage_message
)
from
vllm.utils
import
Counter
,
Device
,
deprecate_kwargs
,
weak_bind
from
vllm.utils
import
Counter
,
Device
,
deprecate_kwargs
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.profiler.prof
import
profile
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
_LOCAL_LOGGING_INTERVAL_SEC
=
5
...
@@ -407,6 +411,19 @@ class LLMEngine:
...
@@ -407,6 +411,19 @@ class LLMEngine:
self
.
tree_decoding
=
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
self
.
tree_decoding
=
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
self
.
seq_id_to_seq_group
:
Dict
[
str
,
SequenceGroupBase
]
=
{}
self
.
seq_id_to_seq_group
:
Dict
[
str
,
SequenceGroupBase
]
=
{}
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
if
self
.
zero_overhead
:
assert
os
.
environ
.
get
(
'HIP_ALLOC_INITIALIZE'
)
==
'0'
self
.
async_d2h
=
None
self
.
last_record
=
None
self
.
async_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
zero_thread
=
threading
.
Thread
(
target
=
self
.
thread_zero_overhead
)
self
.
q_recorder
=
queue
.
Queue
()
self
.
thread_running
=
True
self
.
sem_m2s
=
threading
.
Semaphore
(
0
)
# main to scheduler thread
self
.
zero_thread
.
start
()
profile
.
StartTracer
()
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -1227,6 +1244,35 @@ class LLMEngine:
...
@@ -1227,6 +1244,35 @@ class LLMEngine:
return
None
return
None
def
_fix_last_step
(
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list
=
self
.
async_d2h
.
tolist
()
sample_out_ids
=
output
[
0
].
sampler_out_ids
.
tolist
()
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
in
\
zip
(
seq_group_metadata_list
,
output
[
0
],
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
continue
if
seq_group_metadata
.
do_sample
:
sample
=
sequence_group_outputs
.
samples
[
0
]
assert
len
(
seq_group
.
seqs
)
==
1
seq
=
seq_group
.
seqs
[
0
]
for
token_id
,
seq_id
in
zip
(
sample_out_list
,
sample_out_ids
):
if
seq
.
seq_id
==
seq_id
:
if
type
(
token_id
)
is
list
:
sample
.
output_token
=
token_id
[
0
]
else
:
sample
.
output_token
=
token_id
seq
.
fix_last_token_id
(
sample
.
output_token
)
break
def
_advance_to_next_step
(
def
_advance_to_next_step
(
self
,
output
:
List
[
SamplerOutput
],
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
@@ -1270,6 +1316,131 @@ class LLMEngine:
...
@@ -1270,6 +1316,131 @@ class LLMEngine:
seq_group
.
update_num_computed_tokens
(
1
)
seq_group
.
update_num_computed_tokens
(
1
)
else
:
else
:
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
def
finish_thread
(
self
):
if
self
.
zero_overhead
and
self
.
thread_running
:
self
.
thread_running
=
False
self
.
sem_m2s
.
release
()
def
thread_zero_overhead
(
self
):
logger
.
info
(
'zero overhead thread start!'
)
try
:
while
True
:
self
.
sem_m2s
.
acquire
()
if
not
self
.
thread_running
:
break
virtual_engine
=
0
# Clear outputs for each new scheduler iteration
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
last_outputs_ids
=
None
last_outputs_tensor
=
None
if
self
.
last_record
is
not
None
:
last_output
=
self
.
last_record
[
0
][
0
]
last_outputs_ids
,
last_outputs_tensor
=
last_output
.
sampler_out_ids
,
last_output
.
sampler_out_tenosr
self
.
async_d2h
=
last_outputs_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_event
.
record
()
self
.
q_recorder
.
put
(
self
.
last_record
)
else
:
self
.
q_recorder
.
put
(
None
)
if
len
(
seq_group_metadata_list
)
==
0
:
self
.
last_record
=
None
continue
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
,
last_outputs_ids
=
last_outputs_ids
,
last_outputs_sample
=
last_outputs_tensor
)
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
if
len
(
outputs
)
==
1
:
self
.
_advance_to_next_step
(
outputs
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
=
[
item
for
item
in
scheduler_outputs
.
scheduled_seq_groups
]
#deep copy
self
.
last_record
=
[
outputs
,
seq_group_metadata_list
,
scheduler_outputs
]
except
Exception
as
e
:
print
(
f
"thread_zero_overhead error :
{
e
}
"
)
traceback
.
print_exc
()
def
zero_overhead_step
(
self
)
->
List
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
if
not
self
.
thread_running
:
self
.
zero_thread
.
join
()
self
.
thread_running
=
True
self
.
zero_thread
=
threading
.
Thread
(
target
=
self
.
thread_zero_overhead
)
self
.
zero_thread
.
start
()
self
.
sem_m2s
.
release
()
recode_output
=
self
.
q_recorder
.
get
()
if
recode_output
is
None
:
# None is for the first step
return
None
virtual_engine
=
0
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
.
request_outputs
.
clear
()
outputs
,
seq_group_metadata_list
,
scheduler_outputs
=
recode_output
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
self
.
async_event
.
synchronize
()
self
.
_fix_last_step
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output
:
bool
=
False
if
not
seq_group_metadata_list
\
else
seq_group_metadata_list
[
0
].
state
.
num_steps
==
1
# Add results to the output_queue
ctx
.
append_output
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
True
,
is_last_step
=
True
,
is_first_step_output
=
is_first_step_output
)
# Check if need to run the usual non-async path
#if not allow_async_output_proc:
self
.
_process_model_outputs
(
ctx
=
ctx
)
#profile.ProfRangeAutoPush('has_unfinish')
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
assert
len
(
ctx
.
output_queue
)
==
0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
logger
.
debug
(
"Stopping remote worker execution loop."
)
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
ctx
.
request_outputs
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
...
@@ -1322,6 +1493,13 @@ class LLMEngine:
...
@@ -1322,6 +1493,13 @@ class LLMEngine:
>>> if not (engine.has_unfinished_requests() or example_inputs):
>>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break
>>> break
"""
"""
#traceback.print_stack()
if
self
.
zero_overhead
:
out
=
self
.
zero_overhead_step
()
if
out
is
None
:
#the first step need launch twice
out
=
self
.
zero_overhead_step
()
return
out
if
self
.
parallel_config
.
pipeline_parallel_size
>
1
:
if
self
.
parallel_config
.
pipeline_parallel_size
>
1
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Pipeline parallelism is only supported through AsyncLLMEngine "
"Pipeline parallelism is only supported through AsyncLLMEngine "
...
@@ -1395,14 +1573,14 @@ class LLMEngine:
...
@@ -1395,14 +1573,14 @@ class LLMEngine:
# We use ExecuteModelRequest to pass the last sampled_token_ids
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
virtual_engine
]
#profile.ProfRangeAutoPush('model_executor')
outputs
=
self
.
model_executor
.
execute_model
(
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
#profile.ProfRangeAutoPush('end_executor')
# We need to do this here so that last step's sampled_token_ids can
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
...
@@ -1442,7 +1620,6 @@ class LLMEngine:
...
@@ -1442,7 +1620,6 @@ class LLMEngine:
if
outputs
and
allow_async_output_proc
:
if
outputs
and
allow_async_output_proc
:
assert
len
(
outputs
)
==
1
,
(
assert
len
(
outputs
)
==
1
,
(
"Async postprocessor expects only a single output set"
)
"Async postprocessor expects only a single output set"
)
self
.
_advance_to_next_step
(
self
.
_advance_to_next_step
(
outputs
[
0
],
seq_group_metadata_list
,
outputs
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
...
@@ -1460,6 +1637,7 @@ class LLMEngine:
...
@@ -1460,6 +1637,7 @@ class LLMEngine:
# Multi-step case
# Multi-step case
return
ctx
.
request_outputs
return
ctx
.
request_outputs
#profile.ProfRangeAutoPush('has_unfinish')
if
not
self
.
has_unfinished_requests
():
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
...
@@ -1473,7 +1651,6 @@ class LLMEngine:
...
@@ -1473,7 +1651,6 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
# queued control plane messages, such as add/remove lora adapters.
logger
.
debug
(
"Stopping remote worker execution loop."
)
logger
.
debug
(
"Stopping remote worker execution loop."
)
self
.
model_executor
.
stop_remote_worker_execution_loop
()
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
ctx
.
request_outputs
return
ctx
.
request_outputs
def
_has_remaining_steps
(
def
_has_remaining_steps
(
...
...
vllm/engine/output_processor/stop_checker.py
View file @
f44e9f9e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -20,6 +21,7 @@ class StopChecker:
...
@@ -20,6 +21,7 @@ class StopChecker:
# Do not use it directly, but use `self._get_max_model_len`.
# Do not use it directly, but use `self._get_max_model_len`.
self
.
_max_model_len
=
max_model_len
self
.
_max_model_len
=
max_model_len
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
def
_get_max_model_len
(
self
,
lora_req
:
Optional
[
LoRARequest
]):
def
_get_max_model_len
(
self
,
lora_req
:
Optional
[
LoRARequest
]):
if
lora_req
and
lora_req
.
long_lora_max_len
:
if
lora_req
and
lora_req
.
long_lora_max_len
:
...
@@ -42,12 +44,12 @@ class StopChecker:
...
@@ -42,12 +44,12 @@ class StopChecker:
# Check if the minimum number of tokens has been generated yet;
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
# skip the stop string/token checks if not
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
if
seq
.
get_output_len
(
self
.
zero_overhead
)
<
sampling_params
.
min_tokens
:
return
return
# Check if the sequence has generated the EOS token.
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
and
seq
.
get_last_token_id
(
self
.
zero_overhead
)
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
if
new_char_count
and
(
...
@@ -58,7 +60,7 @@ class StopChecker:
...
@@ -58,7 +60,7 @@ class StopChecker:
# Check if a stop token was encountered.
# Check if a stop token was encountered.
# This assumes a single token produced per step.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
()
last_token_id
=
seq
.
get_last_token_id
(
self
.
zero_overhead
)
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
not
sampling_params
.
include_stop_str_in_output
):
...
@@ -81,12 +83,12 @@ class StopChecker:
...
@@ -81,12 +83,12 @@ class StopChecker:
return
return
# Check if the sequence has reached max_model_len.
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
if
seq
.
get_len
(
self
.
zero_overhead
)
>
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
return
# Check if the sequence has reached max_tokens.
# Check if the sequence has reached max_tokens.
if
seq
.
get_output_len
()
==
sampling_params
.
max_tokens
:
if
seq
.
get_output_len
(
self
.
zero_overhead
)
==
sampling_params
.
max_tokens
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
return
...
...
vllm/entrypoints/llm.py
View file @
f44e9f9e
...
@@ -243,6 +243,9 @@ class LLM:
...
@@ -243,6 +243,9 @@ class LLM:
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
request_counter
=
Counter
()
self
.
request_counter
=
Counter
()
def
__del__
(
self
):
self
.
llm_engine
.
finish_thread
()
@
staticmethod
@
staticmethod
def
get_engine_class
()
->
Type
[
LLMEngine
]:
def
get_engine_class
()
->
Type
[
LLMEngine
]:
...
@@ -1408,6 +1411,8 @@ class LLM:
...
@@ -1408,6 +1411,8 @@ class LLM:
if
use_tqdm
:
if
use_tqdm
:
pbar
.
close
()
pbar
.
close
()
self
.
llm_engine
.
finish_thread
()
# Sort the outputs by request ID.
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# This is necessary because some requests may be finished earlier than
# its previous requests.
# its previous requests.
...
...
vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
View file @
f44e9f9e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
re
import
re
from
typing
import
Optional
,
Sequence
,
Tuple
,
Union
from
collections.abc
import
Sequence
from
typing
import
Optional
,
Union
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
...
@@ -44,6 +45,19 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
...
@@ -44,6 +45,19 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"DeepSeek R1 reasoning parser could not locate think start/end "
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!"
)
"tokens in the tokenizer!"
)
# TODO: need to rebase by PR #14428
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
return
self
.
think_end_token_id
in
input_ids
def
extract_content_ids
(
self
,
input_ids
:
list
[
int
])
->
list
[
int
]:
"""
Extract the content after the end tokens
"""
if
self
.
think_end_token_id
not
in
input_ids
[:
-
1
]:
return
[]
else
:
return
input_ids
[
input_ids
.
index
(
self
.
think_end_token_id
)
+
1
:]
def
extract_reasoning_content_streaming
(
def
extract_reasoning_content_streaming
(
self
,
self
,
previous_text
:
str
,
previous_text
:
str
,
...
@@ -67,6 +81,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
...
@@ -67,6 +81,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
]):
]):
return
None
return
None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if
self
.
think_start_token_id
in
previous_token_ids
:
if
self
.
think_start_token_id
in
previous_token_ids
:
if
self
.
think_end_token_id
in
delta_token_ids
:
if
self
.
think_end_token_id
in
delta_token_ids
:
# <think> in previous, </think> in delta,
# <think> in previous, </think> in delta,
...
@@ -85,7 +101,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
...
@@ -85,7 +101,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
# reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
return
DeltaMessage
(
reasoning_content
=
delta_text
)
elif
self
.
think_start_token_id
in
delta_token_ids
:
elif
self
.
think_start_token_id
in
delta_token_ids
:
logger
.
info
(
delta_text
)
if
self
.
think_end_token_id
in
delta_token_ids
:
if
self
.
think_end_token_id
in
delta_token_ids
:
# <think> in delta, </think> in delta, extract reasoning content
# <think> in delta, </think> in delta, extract reasoning content
start_index
=
delta_text
.
find
(
self
.
think_start_token
)
start_index
=
delta_text
.
find
(
self
.
think_start_token
)
...
@@ -101,35 +116,46 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
...
@@ -101,35 +116,46 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
# reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
return
DeltaMessage
(
reasoning_content
=
delta_text
)
else
:
else
:
# No <think> in previous or delta, reasoning content continues.
# No <think> in previous or delta, also need to check for </think>.
return
DeltaMessage
(
content
=
delta_text
)
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if
self
.
think_end_token_id
in
delta_token_ids
:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index
=
delta_text
.
find
(
self
.
think_end_token
)
reasoning_content
=
delta_text
[:
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
think_end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
)
elif
self
.
think_end_token_id
in
previous_token_ids
:
# </think> in previous, thinking content ends
return
DeltaMessage
(
content
=
delta_text
)
else
:
# no </think> in previous or delta, reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
def
extract_reasoning_content
(
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
T
uple
[
Optional
[
str
],
Optional
[
str
]]:
)
->
t
uple
[
Optional
[
str
],
Optional
[
str
]]:
# Check if the model output contains the <think> tokens.
# DeepSeek R1 doesn't generate <think> now.
if
(
self
.
think_start_token
not
in
model_output
# Thus we assume the reasoning content is always at the start.
or
self
.
think_end_token
not
in
model_output
):
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
return
None
,
model_output
if
self
.
think_end_token
not
in
model_output
:
return
model_output
,
None
else
:
else
:
# Add a start token if it's missing to keep compatibility.
if
self
.
think_start_token
not
in
model_output
:
model_output
=
f
"
{
self
.
think_start_token
}{
model_output
}
"
# Use a regex to find the reasoning content
# Use a regex to find the reasoning content
reasoning_content
=
self
.
reasoning_regex
.
findall
(
model_output
)[
0
]
reasoning_content
=
self
.
reasoning_regex
.
findall
(
model_output
)[
0
]
# Remove the reasoning content from the model output
end_index
=
len
(
# Although deepseek's <think> token is always at the
f
"
{
self
.
think_start_token
}{
reasoning_content
}{
self
.
think_end_token
}
"
# beginning of the line, we cannot guarantee that the
)
# other models will follow this convention.
final_output
=
model_output
[
end_index
:]
# Therefore, we need to add :start_index.
start_index
=
model_output
.
find
(
self
.
think_start_token
)
if
len
(
final_output
)
==
0
:
if
start_index
!=
-
1
:
return
reasoning_content
,
None
end_index
=
start_index
+
len
(
f
"
{
self
.
think_start_token
}{
reasoning_content
}{
self
.
think_end_token
}
"
return
reasoning_content
,
final_output
)
\ No newline at end of file
model_output
=
model_output
[:
start_index
]
+
\
model_output
[
end_index
:]
if
len
(
model_output
)
==
0
:
return
reasoning_content
,
None
return
reasoning_content
,
model_output
vllm/envs.py
View file @
f44e9f9e
...
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
...
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_PCIE_USE_CUSTOM_ALLREDUCE
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_USE_FLUX
:
bool
=
False
VLLM_USE_FLUX
:
bool
=
False
VLLM_FLASH_ATTN_VERSION
:
Optional
[
int
]
=
None
VLLM_FLASH_ATTN_VERSION
:
Optional
[
int
]
=
None
...
@@ -96,6 +97,7 @@ if TYPE_CHECKING:
...
@@ -96,6 +97,7 @@ if TYPE_CHECKING:
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
VLLM_SPEC_DECODE_EAGER
:
bool
=
False
VLLM_SPEC_DECODE_EAGER
:
bool
=
False
VLLM_ENFORCE_EAGER_BS_THRESHOLD
:
Optional
[
int
]
=
None
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -246,6 +248,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -246,6 +248,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_OP"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_OP"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# flag to control vllm to use optimized kernels
"VLLM_PCIE_USE_CUSTOM_ALLREDUCE"
:
lambda
:
bool
(
int
(
os
.
environ
.
get
(
"VLLM_PCIE_USE_CUSTOM_ALLREDUCE"
,
"0"
))),
# flag to control vllm to use optimized tc paged attn kernels
# flag to control vllm to use optimized tc paged attn kernels
"VLLM_USE_TC_PAGED_ATTN"
:
"VLLM_USE_TC_PAGED_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TC_PAGED_ATTN"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TC_PAGED_ATTN"
,
"True"
).
lower
()
in
...
@@ -623,6 +629,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -623,6 +629,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vLLM will disable the draft model in cudagraph mode.
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_SPEC_DECODE_EAGER"
:
"VLLM_SPEC_DECODE_EAGER"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_SPEC_DECODE_EAGER"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_SPEC_DECODE_EAGER"
,
"0"
))),
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_ENFORCE_EAGER_BS_THRESHOLD"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_ENFORCE_EAGER_BS_THRESHOLD"
,
"-1"
)),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment