Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
823ab796
Unverified
Commit
823ab796
authored
Jan 28, 2025
by
Harry Mellor
Committed by
GitHub
Jan 27, 2025
Browse files
Update `pre-commit` hooks (#12475)
Signed-off-by:
Harry Mellor
<
19981378+hmellor@users.noreply.github.com
>
parent
6116ca8c
Changes
64
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
79 additions
and
73 deletions
+79
-73
vllm/attention/ops/triton_flash_attention.py
vllm/attention/ops/triton_flash_attention.py
+2
-2
vllm/attention/selector.py
vllm/attention/selector.py
+2
-2
vllm/config.py
vllm/config.py
+4
-3
vllm/core/block/common.py
vllm/core/block/common.py
+4
-3
vllm/core/block_manager.py
vllm/core/block_manager.py
+2
-2
vllm/core/scheduler.py
vllm/core/scheduler.py
+11
-12
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+4
-4
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+4
-4
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+2
-2
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+4
-5
vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
...ypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
+2
-2
vllm/lora/layers.py
vllm/lora/layers.py
+8
-4
vllm/lora/models.py
vllm/lora/models.py
+3
-2
vllm/lora/ops/triton_ops/sgmv_expand.py
vllm/lora/ops/triton_ops/sgmv_expand.py
+2
-3
vllm/lora/ops/triton_ops/sgmv_shrink.py
vllm/lora/ops/triton_ops/sgmv_shrink.py
+2
-2
vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py
...rs/quantization/kernels/mixed_precision/MPLinearKernel.py
+6
-6
vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
...rs/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
+7
-7
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+4
-3
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+2
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+4
-3
No files found.
vllm/attention/ops/triton_flash_attention.py
View file @
823ab796
...
@@ -627,8 +627,8 @@ def attn_fwd(
...
@@ -627,8 +627,8 @@ def attn_fwd(
causal_start_idx
,
causal_start_idx
,
dtype
=
tl
.
int32
)
dtype
=
tl
.
int32
)
mask_m_offsets
=
start_m_idx
+
tl
.
arange
(
0
,
BLOCK_M
)
mask_m_offsets
=
start_m_idx
+
tl
.
arange
(
0
,
BLOCK_M
)
out_ptrs_mask
=
(
mask_m_offsets
[:,
None
]
>=
out_ptrs_mask
=
(
mask_m_offsets
[:,
None
]
out_mask_boundary
[
None
,
:])
>=
out_mask_boundary
[
None
,
:])
z
=
0.0
z
=
0.0
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
# write back LSE
# write back LSE
...
...
vllm/attention/selector.py
View file @
823ab796
import
os
import
os
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
functools
import
lru_
cache
from
functools
import
cache
from
typing
import
Generator
,
Optional
,
Type
from
typing
import
Generator
,
Optional
,
Type
import
torch
import
torch
...
@@ -100,7 +100,7 @@ def get_attn_backend(
...
@@ -100,7 +100,7 @@ def get_attn_backend(
)
)
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
_cached_get_attn_backend
(
def
_cached_get_attn_backend
(
head_size
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
...
...
vllm/config.py
View file @
823ab796
...
@@ -67,7 +67,8 @@ _RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
...
@@ -67,7 +67,8 @@ _RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
_TASK_RUNNER
:
Dict
[
_ResolvedTask
,
RunnerType
]
=
{
_TASK_RUNNER
:
Dict
[
_ResolvedTask
,
RunnerType
]
=
{
task
:
runner
task
:
runner
for
runner
,
tasks
in
_RUNNER_TASKS
.
items
()
for
task
in
tasks
for
runner
,
tasks
in
_RUNNER_TASKS
.
items
()
for
task
in
tasks
}
}
HfOverrides
=
Union
[
Dict
[
str
,
Any
],
Callable
[[
PretrainedConfig
],
HfOverrides
=
Union
[
Dict
[
str
,
Any
],
Callable
[[
PretrainedConfig
],
...
@@ -1976,8 +1977,8 @@ class SpeculativeConfig:
...
@@ -1976,8 +1977,8 @@ class SpeculativeConfig:
"typical_acceptance_sampler."
)
"typical_acceptance_sampler."
)
if
(
self
.
draft_token_acceptance_method
!=
'rejection_sampler'
if
(
self
.
draft_token_acceptance_method
!=
'rejection_sampler'
and
self
.
draft_token_acceptance_method
!=
and
self
.
draft_token_acceptance_method
'typical_acceptance_sampler'
):
!=
'typical_acceptance_sampler'
):
raise
ValueError
(
raise
ValueError
(
"Expected draft_token_acceptance_method to be either "
"Expected draft_token_acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
"rejection_sampler or typical_acceptance_sampler. Instead it "
...
...
vllm/core/block/common.py
View file @
823ab796
...
@@ -34,9 +34,10 @@ class RefCounter(RefCounterProtocol):
...
@@ -34,9 +34,10 @@ class RefCounter(RefCounterProtocol):
def
__init__
(
self
,
all_block_indices
:
Iterable
[
BlockId
]):
def
__init__
(
self
,
all_block_indices
:
Iterable
[
BlockId
]):
deduped
=
set
(
all_block_indices
)
deduped
=
set
(
all_block_indices
)
self
.
_refcounts
:
Dict
[
BlockId
,
self
.
_refcounts
:
Dict
[
BlockId
,
RefCount
]
=
{
RefCount
]
=
{
index
:
0
index
:
0
for
index
in
deduped
}
for
index
in
deduped
}
def
incr
(
self
,
block_id
:
BlockId
)
->
RefCount
:
def
incr
(
self
,
block_id
:
BlockId
)
->
RefCount
:
assert
block_id
in
self
.
_refcounts
assert
block_id
in
self
.
_refcounts
...
...
vllm/core/block_manager.py
View file @
823ab796
...
@@ -136,8 +136,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
...
@@ -136,8 +136,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
device
=
Device
.
GPU
)
device
=
Device
.
GPU
)
# Use watermark to avoid frequent cache eviction.
# Use watermark to avoid frequent cache eviction.
if
(
self
.
num_total_gpu_blocks
-
num_required_blocks
<
if
(
self
.
num_total_gpu_blocks
-
num_required_blocks
self
.
watermark_blocks
):
<
self
.
watermark_blocks
):
return
AllocStatus
.
NEVER
return
AllocStatus
.
NEVER
if
num_free_gpu_blocks
-
num_required_blocks
>=
self
.
watermark_blocks
:
if
num_free_gpu_blocks
-
num_required_blocks
>=
self
.
watermark_blocks
:
return
AllocStatus
.
OK
return
AllocStatus
.
OK
...
...
vllm/core/scheduler.py
View file @
823ab796
...
@@ -988,8 +988,8 @@ class Scheduler:
...
@@ -988,8 +988,8 @@ class Scheduler:
waiting_queue
.
popleft
()
waiting_queue
.
popleft
()
continue
continue
if
(
budget
.
num_batched_tokens
>=
if
(
budget
.
num_batched_tokens
self
.
scheduler_config
.
max_num_batched_tokens
):
>=
self
.
scheduler_config
.
max_num_batched_tokens
):
# We've reached the budget limit - since there might be
# We've reached the budget limit - since there might be
# continuous prefills in the running queue, we should break
# continuous prefills in the running queue, we should break
# to avoid scheduling any new prefills.
# to avoid scheduling any new prefills.
...
@@ -1096,8 +1096,8 @@ class Scheduler:
...
@@ -1096,8 +1096,8 @@ class Scheduler:
running_scheduled
.
swapped_out
)
==
0
:
running_scheduled
.
swapped_out
)
==
0
:
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
assert
(
budget
.
num_batched_tokens
<=
assert
(
budget
.
num_batched_tokens
self
.
scheduler_config
.
max_num_batched_tokens
)
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# Update waiting requests.
# Update waiting requests.
...
@@ -1189,8 +1189,8 @@ class Scheduler:
...
@@ -1189,8 +1189,8 @@ class Scheduler:
curr_loras
,
curr_loras
,
enable_chunking
=
True
)
enable_chunking
=
True
)
assert
(
budget
.
num_batched_tokens
<=
assert
(
budget
.
num_batched_tokens
self
.
scheduler_config
.
max_num_batched_tokens
)
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# Update waiting requests.
# Update waiting requests.
...
@@ -1358,8 +1358,8 @@ class Scheduler:
...
@@ -1358,8 +1358,8 @@ class Scheduler:
# NOTE: We use get_len instead of get_prompt_len because when
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# a sequence is preempted, prefill includes previous generated
# output tokens.
# output tokens.
if
(
token_chunk_size
+
num_computed_tokens
<
if
(
token_chunk_size
+
num_computed_tokens
seqs
[
0
].
data
.
get_len
()):
<
seqs
[
0
].
data
.
get_len
()):
do_sample
=
False
do_sample
=
False
# It assumes the scheduled_seq_groups is ordered by
# It assumes the scheduled_seq_groups is ordered by
...
@@ -1625,10 +1625,9 @@ class Scheduler:
...
@@ -1625,10 +1625,9 @@ class Scheduler:
if
self
.
scheduler_config
.
delay_factor
>
0
and
self
.
waiting
:
if
self
.
scheduler_config
.
delay_factor
>
0
and
self
.
waiting
:
earliest_arrival_time
=
min
(
earliest_arrival_time
=
min
(
[
e
.
metrics
.
arrival_time
for
e
in
self
.
waiting
])
[
e
.
metrics
.
arrival_time
for
e
in
self
.
waiting
])
passed_delay
=
(
passed_delay
=
((
now
-
earliest_arrival_time
)
(
now
-
earliest_arrival_time
)
>
>
(
self
.
scheduler_config
.
delay_factor
*
(
self
.
scheduler_config
.
delay_factor
*
self
.
last_prompt_latency
)
self
.
last_prompt_latency
)
or
not
self
.
running
)
or
not
self
.
running
)
else
:
else
:
passed_delay
=
True
passed_delay
=
True
return
passed_delay
return
passed_delay
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
823ab796
...
@@ -352,8 +352,8 @@ class MessageQueue:
...
@@ -352,8 +352,8 @@ class MessageQueue:
sched_yield
()
sched_yield
()
# if we wait for a long time, log a message
# if we wait for a long time, log a message
if
(
time
.
monotonic
()
-
start_time
>
if
(
time
.
monotonic
()
-
start_time
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
):
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
):
logger
.
debug
(
"No available block found in %s second. "
,
logger
.
debug
(
"No available block found in %s second. "
,
VLLM_RINGBUFFER_WARNING_INTERVAL
)
VLLM_RINGBUFFER_WARNING_INTERVAL
)
n_warning
+=
1
n_warning
+=
1
...
@@ -410,8 +410,8 @@ class MessageQueue:
...
@@ -410,8 +410,8 @@ class MessageQueue:
sched_yield
()
sched_yield
()
# if we wait for a long time, log a message
# if we wait for a long time, log a message
if
(
time
.
monotonic
()
-
start_time
>
if
(
time
.
monotonic
()
-
start_time
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
):
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
):
logger
.
debug
(
"No available block found in %s second. "
,
logger
.
debug
(
"No available block found in %s second. "
,
VLLM_RINGBUFFER_WARNING_INTERVAL
)
VLLM_RINGBUFFER_WARNING_INTERVAL
)
n_warning
+=
1
n_warning
+=
1
...
...
vllm/distributed/parallel_state.py
View file @
823ab796
...
@@ -1014,8 +1014,8 @@ def initialize_model_parallel(
...
@@ -1014,8 +1014,8 @@ def initialize_model_parallel(
backend
=
backend
or
torch
.
distributed
.
get_backend
(
backend
=
backend
or
torch
.
distributed
.
get_backend
(
get_world_group
().
device_group
)
get_world_group
().
device_group
)
if
(
world_size
!=
if
(
world_size
tensor_model_parallel_size
*
pipeline_model_parallel_size
):
!=
tensor_model_parallel_size
*
pipeline_model_parallel_size
):
raise
RuntimeError
(
raise
RuntimeError
(
f
"world_size (
{
world_size
}
) is not equal to "
f
"world_size (
{
world_size
}
) is not equal to "
f
"tensor_model_parallel_size (
{
tensor_model_parallel_size
}
) x "
f
"tensor_model_parallel_size (
{
tensor_model_parallel_size
}
) x "
...
@@ -1069,8 +1069,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
...
@@ -1069,8 +1069,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
return
return
if
all
([
if
all
([
vllm_config
.
kv_transfer_config
.
need_kv_parallel_group
,
vllm_config
.
kv_transfer_config
.
need_kv_parallel_group
,
_KV_TRANSFER
_KV_TRANSFER
is
None
is
None
]):
]):
_KV_TRANSFER
=
kv_transfer
.
KVTransferAgent
(
_KV_TRANSFER
=
kv_transfer
.
KVTransferAgent
(
rank
=
get_world_group
().
rank
,
rank
=
get_world_group
().
rank
,
...
...
vllm/entrypoints/chat_utils.py
View file @
823ab796
...
@@ -3,7 +3,7 @@ import codecs
...
@@ -3,7 +3,7 @@ import codecs
import
json
import
json
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
functools
import
lru_cache
,
partial
from
functools
import
cache
,
lru_cache
,
partial
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Iterable
,
List
,
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
Literal
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
...
@@ -377,7 +377,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -377,7 +377,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return
self
.
_model_config
.
allowed_local_media_path
return
self
.
_model_config
.
allowed_local_media_path
@
staticmethod
@
staticmethod
@
lru_
cache
(
maxsize
=
None
)
@
cache
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
return
tokenizer
.
decode
(
token_index
)
return
tokenizer
.
decode
(
token_index
)
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
823ab796
...
@@ -522,8 +522,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -522,8 +522,7 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs
.
append
({
out_top_logprobs
.
append
({
# Convert float("-inf") to the
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
# JSON-serializable float that OpenAI uses
self
.
_get_decoded_token
(
self
.
_get_decoded_token
(
top_lp
[
1
],
top_lp
[
1
],
top_lp
[
0
],
top_lp
[
0
],
tokenizer
,
tokenizer
,
return_as_token_id
=
self
.
return_tokens_as_token_ids
):
return_as_token_id
=
self
.
return_tokens_as_token_ids
):
...
...
vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
View file @
823ab796
...
@@ -62,8 +62,8 @@ class Granite20bFCToolParser(ToolParser):
...
@@ -62,8 +62,8 @@ class Granite20bFCToolParser(ToolParser):
start_of_json
=
match
.
end
()
start_of_json
=
match
.
end
()
# end_index == the start of the next function call
# end_index == the start of the next function call
# (if exists)
# (if exists)
next_function_call_start
=
(
matches
[
i
+
1
].
start
()
next_function_call_start
=
(
matches
[
i
+
1
].
start
()
if
i
+
if
i
+
1
<
len
(
matches
)
else
None
)
1
<
len
(
matches
)
else
None
)
raw_function_calls
.
append
(
raw_function_calls
.
append
(
dec
.
raw_decode
(
dec
.
raw_decode
(
...
...
vllm/lora/layers.py
View file @
823ab796
...
@@ -220,8 +220,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -220,8 +220,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_b
.
T
,
non_blocking
=
True
)
lora_b
.
T
,
non_blocking
=
True
)
if
embeddings_tensor
is
not
None
:
if
embeddings_tensor
is
not
None
:
self
.
embeddings_tensors
[
self
.
embeddings_tensors
[
index
,
:
embeddings_tensor
.
shape
[
0
],
:
embeddings_tensor
.
index
,
shape
[
1
],
].
copy_
(
embeddings_tensor
,
non_blocking
=
True
)
:
embeddings_tensor
.
shape
[
0
],
:
embeddings_tensor
.
shape
[
1
],
].
copy_
(
embeddings_tensor
,
non_blocking
=
True
)
if
self
.
embeddings_slice
is
not
None
:
if
self
.
embeddings_slice
is
not
None
:
# TODO(yard1): Optimize this copy, we don't need to copy
# TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part
# everything, just the modified part
...
@@ -1024,8 +1026,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -1024,8 +1026,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
lora_b
.
T
,
non_blocking
=
True
)
lora_b
.
T
,
non_blocking
=
True
)
if
embeddings_tensor
is
not
None
:
if
embeddings_tensor
is
not
None
:
self
.
embeddings_tensors
[
self
.
embeddings_tensors
[
index
,
:
embeddings_tensor
.
shape
[
0
],
:
embeddings_tensor
.
index
,
shape
[
1
],
]
=
embeddings_tensor
:
embeddings_tensor
.
shape
[
0
],
:
embeddings_tensor
.
shape
[
1
],
]
=
embeddings_tensor
def
_get_logits
(
def
_get_logits
(
self
,
self
,
...
...
vllm/lora/models.py
View file @
823ab796
...
@@ -75,8 +75,9 @@ class LoRAModel(AdapterModel):
...
@@ -75,8 +75,9 @@ class LoRAModel(AdapterModel):
# Scaling factor for long context lora model. None if it is not
# Scaling factor for long context lora model. None if it is not
# fine tuned for the long context.
# fine tuned for the long context.
self
.
scaling_factor
=
scaling_factor
self
.
scaling_factor
=
scaling_factor
assert
(
lora_model_id
>
assert
(
0
),
f
"a valid lora id should be greater than 0, got
{
self
.
id
}
"
lora_model_id
>
0
),
f
"a valid lora id should be greater than 0, got
{
self
.
id
}
"
self
.
rank
=
rank
self
.
rank
=
rank
self
.
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
loras
self
.
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
loras
...
...
vllm/lora/ops/triton_ops/sgmv_expand.py
View file @
823ab796
...
@@ -136,9 +136,8 @@ def _sgmv_expand_kernel(
...
@@ -136,9 +136,8 @@ def _sgmv_expand_kernel(
c_ptr
=
(
out_ptr
+
offset_cm
[:,
None
]
*
output_d0_stride
+
c_ptr
=
(
out_ptr
+
offset_cm
[:,
None
]
*
output_d0_stride
+
offset_cn
[
None
,
:]
*
output_d1_stride
)
offset_cn
[
None
,
:]
*
output_d1_stride
)
M
=
tl
.
load
(
seq_lens
+
cur_batch
)
M
=
tl
.
load
(
seq_lens
+
cur_batch
)
c_mask
=
(
offset_cm
[:,
None
]
<
c_mask
=
(
offset_cm
[:,
None
]
<
(
cur_seq_start
+
M
))
&
(
(
cur_seq_start
+
M
))
&
(
offset_cn
[
None
,
:]
<
offset_cn
[
None
,
:]
<
(
cur_slice_start
+
curr_N
))
(
cur_slice_start
+
curr_N
))
if
ADD_INPUTS
:
if
ADD_INPUTS
:
tiled_out
=
tl
.
load
(
c_ptr
,
mask
=
c_mask
)
tiled_out
=
tl
.
load
(
c_ptr
,
mask
=
c_mask
)
tiled_c
+=
tiled_out
tiled_c
+=
tiled_out
...
...
vllm/lora/ops/triton_ops/sgmv_shrink.py
View file @
823ab796
...
@@ -114,8 +114,8 @@ def _sgmv_shrink_kernel(
...
@@ -114,8 +114,8 @@ def _sgmv_shrink_kernel(
slice_id
*
output_d0_stride
)
slice_id
*
output_d0_stride
)
c_ptr
=
cur_out_ptr
+
offset_cm
[:,
None
]
*
output_d1_stride
+
offset_cn
[
c_ptr
=
cur_out_ptr
+
offset_cm
[:,
None
]
*
output_d1_stride
+
offset_cn
[
None
,
:]
*
output_d2_stride
None
,
:]
*
output_d2_stride
c_mask
=
(
offset_cm
[:,
None
]
<
c_mask
=
(
offset_cm
[:,
None
]
<
(
cur_seq_start
+
M
))
&
(
offset_cn
[
None
,
:]
(
cur_seq_start
+
M
))
&
(
offset_cn
[
None
,
:]
<
N
)
<
N
)
accumulator
*=
scaling
accumulator
*=
scaling
# handles write-back with reduction-splitting
# handles write-back with reduction-splitting
if
SPLIT_K
==
1
:
if
SPLIT_K
==
1
:
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py
View file @
823ab796
...
@@ -73,8 +73,8 @@ class MPLinearKernel(ABC):
...
@@ -73,8 +73,8 @@ class MPLinearKernel(ABC):
torch
.
nn
.
Parameter
(
new_param
.
data
,
requires_grad
=
False
))
torch
.
nn
.
Parameter
(
new_param
.
data
,
requires_grad
=
False
))
def
_get_weight_params
(
def
_get_weight_params
(
self
,
layer
:
torch
.
nn
.
Module
self
,
layer
:
torch
.
nn
.
Module
)
->
Tuple
[
)
->
Tuple
[
torch
.
Tensor
,
# w_q
torch
.
Tensor
,
# w_q
torch
.
Tensor
,
# w_s
torch
.
Tensor
,
# w_s
Optional
[
torch
.
Tensor
],
# w_zp,
Optional
[
torch
.
Tensor
],
# w_zp,
Optional
[
torch
.
Tensor
]
# w_gidx
Optional
[
torch
.
Tensor
]
# w_gidx
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
View file @
823ab796
...
@@ -48,8 +48,8 @@ class ScaledMMLinearKernel(ABC):
...
@@ -48,8 +48,8 @@ class ScaledMMLinearKernel(ABC):
raise
NotImplementedError
raise
NotImplementedError
def
_get_weight_params
(
def
_get_weight_params
(
self
,
layer
:
torch
.
nn
.
Module
self
,
layer
:
torch
.
nn
.
Module
)
->
Tuple
[
)
->
Tuple
[
torch
.
Tensor
,
# weight
torch
.
Tensor
,
# weight
torch
.
Tensor
,
# weight_scale
torch
.
Tensor
,
# weight_scale
Optional
[
torch
.
Tensor
],
# input_scale,
Optional
[
torch
.
Tensor
],
# input_scale,
Optional
[
torch
.
Tensor
],
# input_zp
Optional
[
torch
.
Tensor
],
# input_zp
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
823ab796
...
@@ -72,9 +72,10 @@ def block_quant_to_tensor_quant(
...
@@ -72,9 +72,10 @@ def block_quant_to_tensor_quant(
x_dq_block
=
x_q_block
.
to
(
torch
.
float32
)
x_dq_block
=
x_q_block
.
to
(
torch
.
float32
)
x_dq_block_tiles
=
[[
x_dq_block_tiles
=
[[
x_dq_block
[
j
*
block_n
:
min
((
j
+
1
)
*
block_n
,
n
),
x_dq_block
[
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
k
),
]
j
*
block_n
:
min
((
j
+
1
)
*
block_n
,
n
),
for
i
in
range
(
k_tiles
)
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
k
),
]
for
i
in
range
(
k_tiles
)
]
for
j
in
range
(
n_tiles
)]
]
for
j
in
range
(
n_tiles
)]
for
i
in
range
(
k_tiles
):
for
i
in
range
(
k_tiles
):
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
823ab796
...
@@ -73,8 +73,8 @@ def requantize_with_max_scale(
...
@@ -73,8 +73,8 @@ def requantize_with_max_scale(
# from disk in this case. Skip requantization in this case (since)
# from disk in this case. Skip requantization in this case (since)
# we already are quantized with the single scale.
# we already are quantized with the single scale.
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
unfused_module_in_checkpoint
=
(
weight_scale
[
-
1
]
>
torch
.
finfo
(
unfused_module_in_checkpoint
=
(
weight_scale
[
-
1
]
torch
.
float8_e4m3fn
).
min
)
>
torch
.
finfo
(
torch
.
float8_e4m3fn
).
min
)
# If unfused checkpoint, need requanize with the single scale.
# If unfused checkpoint, need requanize with the single scale.
if
unfused_module_in_checkpoint
:
if
unfused_module_in_checkpoint
:
...
...
vllm/model_executor/layers/sampler.py
View file @
823ab796
...
@@ -716,9 +716,10 @@ def _sample_with_torch(
...
@@ -716,9 +716,10 @@ def _sample_with_torch(
tensors required for Pythonization
tensors required for Pythonization
'''
'''
categorized_seq_group_ids
:
Dict
[
SamplingType
,
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
List
[
int
]]
=
{
t
:
[]
t
:
[]
for
t
in
SamplingType
}
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment