Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9e12416
Commit
b9e12416
authored
May 31, 2024
by
zhuwenwen
Browse files
merge v0.4.3
parents
e5d707db
e9d3aa04
Changes
345
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1515 additions
and
310 deletions
+1515
-310
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+22
-53
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+155
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+57
-52
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+2
-2
vllm/inputs.py
vllm/inputs.py
+130
-0
vllm/logger.py
vllm/logger.py
+2
-1
vllm/lora/fully_sharded_layers.py
vllm/lora/fully_sharded_layers.py
+30
-24
vllm/lora/layers.py
vllm/lora/layers.py
+127
-8
vllm/lora/models.py
vllm/lora/models.py
+174
-20
vllm/lora/request.py
vllm/lora/request.py
+2
-0
vllm/lora/utils.py
vllm/lora/utils.py
+12
-5
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+40
-9
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+3
-1
vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json
.../configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json
+128
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json
...e/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json
+110
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json
...e/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json
+128
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
...e/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
+128
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+90
-47
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+161
-85
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+14
-3
No files found.
Too many changes to show.
To preserve performance only
345 of 345+
files are displayed.
Plain diff
Email patch
vllm/executor/gpu_executor.py
View file @
b9e12416
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
make_async
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):
...
@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
"""Initialize the worker and load the model.
"""Initialize the worker and load the model.
If speculative decoding is enabled, we instead create the speculative
worker.
"""
"""
if
self
.
speculative_config
is
None
:
assert
self
.
parallel_config
.
world_size
==
1
,
(
self
.
_init_non_spec_worker
()
"GPUExecutor only supports single GPU."
)
else
:
self
.
_init_spec_worker
()
self
.
driver_worker
=
self
.
_create_worker
()
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
load_model
()
def
_get_worker_kwargs
(
def
_get_worker_kwargs
(
self
,
self
,
...
@@ -45,6 +44,7 @@ class GPUExecutor(ExecutorBase):
...
@@ -45,6 +44,7 @@ class GPUExecutor(ExecutorBase):
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
vision_language_config
=
self
.
vision_language_config
,
speculative_config
=
self
.
speculative_config
,
is_driver_worker
=
rank
==
0
,
is_driver_worker
=
rank
==
0
,
)
)
...
@@ -52,53 +52,22 @@ class GPUExecutor(ExecutorBase):
...
@@ -52,53 +52,22 @@ class GPUExecutor(ExecutorBase):
local_rank
:
int
=
0
,
local_rank
:
int
=
0
,
rank
:
int
=
0
,
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
):
distributed_init_method
:
Optional
[
str
]
=
None
):
if
self
.
speculative_config
is
None
:
worker_module_name
=
"vllm.worker.worker"
worker_class_name
=
"Worker"
else
:
worker_module_name
=
"vllm.spec_decode.spec_decode_worker"
worker_class_name
=
"create_spec_worker"
wrapper
=
WorkerWrapperBase
(
wrapper
=
WorkerWrapperBase
(
worker_module_name
=
"vllm.worker.worker"
,
worker_module_name
=
worker_module_name
,
worker_class_name
=
"W
orker
"
,
worker_class_name
=
w
orker
_class_name
,
)
)
wrapper
.
init_worker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
wrapper
.
init_worker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
distributed_init_method
))
return
wrapper
.
worker
return
wrapper
.
worker
def
_init_non_spec_worker
(
self
):
assert
self
.
parallel_config
.
world_size
==
1
,
(
"GPUExecutor only supports single GPU."
)
self
.
driver_worker
=
self
.
_create_worker
()
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
load_model
()
def
_init_spec_worker
(
self
):
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
"""
assert
self
.
speculative_config
is
not
None
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
target_worker
=
self
.
_create_worker
()
draft_worker_kwargs
=
self
.
_get_worker_kwargs
()
# Override draft-model specific worker args.
draft_worker_kwargs
.
update
(
model_config
=
self
.
speculative_config
.
draft_model_config
,
parallel_config
=
self
.
speculative_config
.
draft_parallel_config
,
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)
spec_decode_worker
=
SpecDecodeWorker
.
create_worker
(
scorer_worker
=
target_worker
,
draft_worker_kwargs
=
draft_worker_kwargs
,
)
assert
self
.
parallel_config
.
world_size
==
1
,
(
"GPUExecutor only supports single GPU."
)
self
.
driver_worker
=
spec_decode_worker
# Load model handled in spec decode worker.
self
.
driver_worker
.
init_device
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of available KV blocks by invoking the
"""Determine the number of available KV blocks by invoking the
underlying worker.
underlying worker.
...
@@ -117,8 +86,8 @@ class GPUExecutor(ExecutorBase):
...
@@ -117,8 +86,8 @@ class GPUExecutor(ExecutorBase):
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
execute_model
(
def
execute_model
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]
]:
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
return
output
return
output
...
@@ -144,7 +113,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
...
@@ -144,7 +113,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async
def
execute_model_async
(
async
def
execute_model_async
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
execute_model_req
=
execute_model_req
,
)
)(
execute_model_req
=
execute_model_req
,
)
return
output
return
output
vllm/executor/multiproc_gpu_executor.py
0 → 100644
View file @
b9e12416
import
asyncio
import
os
from
functools
import
partial
from
typing
import
Any
,
List
,
Optional
from
vllm.executor.distributed_gpu_executor
import
(
# yapf: disable
DistributedGPUExecutor
,
DistributedGPUExecutorAsync
)
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
ResultHandler
,
WorkerMonitor
)
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
logger
=
init_logger
(
__name__
)
class
MultiprocessingGPUExecutor
(
DistributedGPUExecutor
):
"""Python multiprocessing-based multi-GPU executor"""
def
_init_executor
(
self
)
->
None
:
assert
(
not
self
.
speculative_config
),
"Speculative decoding not yet supported for MultiProcGPU backend."
# Create the parallel GPU workers.
world_size
=
self
.
parallel_config
.
tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if
"CUDA_VISIBLE_DEVICES"
not
in
os
.
environ
:
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
(
","
.
join
(
map
(
str
,
range
(
world_size
))))
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os
.
environ
[
"VLLM_INSTANCE_ID"
]
=
get_vllm_instance_id
()
from
torch.cuda
import
device_count
assert
world_size
<=
device_count
(),
(
"please set tensor_parallel_size to less than max local gpu count"
)
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
if
world_size
==
1
:
self
.
workers
=
[]
else
:
result_handler
=
ResultHandler
()
self
.
workers
=
[
ProcessWorkerWrapper
(
result_handler
,
partial
(
self
.
_create_worker
,
rank
=
rank
,
local_rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
))
for
rank
in
range
(
1
,
world_size
)
]
self
.
worker_monitor
=
WorkerMonitor
(
self
.
workers
,
result_handler
)
result_handler
.
start
()
self
.
worker_monitor
.
start
()
self
.
driver_worker
=
self
.
_create_worker
(
distributed_init_method
=
distributed_init_method
)
self
.
_run_workers
(
"init_device"
)
self
.
_run_workers
(
"load_model"
,
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
)
def
shutdown
(
self
):
if
(
worker_monitor
:
=
getattr
(
self
,
"worker_monitor"
,
None
))
is
not
None
:
worker_monitor
.
close
()
def
_driver_execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return
self
.
driver_worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
def
_run_workers
(
self
,
method
:
str
,
*
args
,
async_run_remote_workers_only
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""
if
max_concurrent_workers
:
raise
NotImplementedError
(
"max_concurrent_workers is not supported yet."
)
# Start the workers first.
worker_outputs
=
[
worker
.
execute_method
(
method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
workers
]
if
async_run_remote_workers_only
:
# Just return futures
return
worker_outputs
driver_worker_method
=
getattr
(
self
.
driver_worker
,
method
)
driver_worker_output
=
driver_worker_method
(
*
args
,
**
kwargs
)
# Get the results of the workers.
return
[
driver_worker_output
]
+
[
output
.
get
()
for
output
in
worker_outputs
]
def
check_health
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
if
not
self
.
worker_monitor
.
is_alive
():
raise
RuntimeError
(
"Worker processes are not running"
)
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
for
result
in
parallel_worker_tasks
:
result
.
get
()
class
MultiprocessingGPUExecutorAsync
(
MultiprocessingGPUExecutor
,
DistributedGPUExecutorAsync
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
driver_exec_model
=
make_async
(
self
.
driver_worker
.
execute_model
)
async
def
_driver_execute_model_async
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
return
await
self
.
driver_exec_model
(
execute_model_req
)
async
def
_start_worker_execution_loop
(
self
):
coros
=
[
worker
.
execute_method_async
(
"start_worker_execution_loop"
)
for
worker
in
self
.
workers
]
return
await
asyncio
.
gather
(
*
coros
)
vllm/executor/ray_gpu_executor.py
View file @
b9e12416
...
@@ -28,10 +28,7 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
...
@@ -28,10 +28,7 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class
RayGPUExecutor
(
DistributedGPUExecutor
):
class
RayGPUExecutor
(
DistributedGPUExecutor
):
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
assert
(
not
self
.
speculative_config
assert
self
.
parallel_config
.
distributed_executor_backend
==
"ray"
),
"Speculative decoding not yet supported for RayGPU backend."
assert
self
.
parallel_config
.
worker_use_ray
placement_group
=
self
.
parallel_config
.
placement_group
placement_group
=
self
.
parallel_config
.
placement_group
# Disable Ray usage stats collection.
# Disable Ray usage stats collection.
...
@@ -45,6 +42,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -45,6 +42,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
self
.
forward_dag
=
None
self
.
forward_dag
=
None
if
USE_RAY_COMPILED_DAG
:
if
USE_RAY_COMPILED_DAG
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
self
.
extra_execute_model_run_workers_kwargs
[
"use_ray_compiled_dag"
]
=
True
def
_configure_ray_workers_use_nsight
(
self
,
def
_configure_ray_workers_use_nsight
(
self
,
ray_remote_kwargs
)
->
Dict
[
str
,
Any
]:
ray_remote_kwargs
)
->
Dict
[
str
,
Any
]:
...
@@ -90,14 +89,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -90,14 +89,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_capture_child_tasks
=
True
,
placement_group_capture_child_tasks
=
True
,
placement_group_bundle_index
=
bundle_id
,
placement_group_bundle_index
=
bundle_id
,
)
)
if
self
.
speculative_config
is
not
None
:
worker_module_name
=
"vllm.spec_decode.spec_decode_worker"
worker_class_name
=
"create_spec_worker"
else
:
worker_module_name
=
"vllm.worker.worker"
worker_class_name
=
"Worker"
worker
=
ray
.
remote
(
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_cpus
=
0
,
num_gpus
=
num_gpus
,
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
)(
RayWorkerWrapper
).
remote
(
worker_module_name
=
"vllm.worker.worker"
,
worker_module_name
=
worker_module_name
,
worker_class_name
=
"W
orker
"
,
worker_class_name
=
w
orker
_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
)
...
@@ -107,8 +114,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -107,8 +114,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process.
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
self
.
driver_worker
=
RayWorkerWrapper
(
worker_module_name
=
"vllm.worker.worker"
,
worker_module_name
=
worker_module_name
,
worker_class_name
=
"W
orker
"
,
worker_class_name
=
w
orker
_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
)
else
:
else
:
...
@@ -166,23 +173,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -166,23 +173,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers
=
self
.
parallel_config
.
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
)
max_parallel_loading_workers
)
def
execute_model
(
def
_driver_execute_model
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
all_outputs
=
self
.
_run_workers
(
)
->
List
[
SamplerOutput
]:
"execute_model"
,
"""Run execute_model in the driver worker.
driver_kwargs
=
{
"execute_model_req"
:
execute_model_req
},
use_ray_compiled_dag
=
USE_RAY_COMPILED_DAG
)
# Only the driver worker returns the sampling results.
Passing None will cause the driver to stop the model execution
return
all_outputs
[
0
]
loop running in each of the remote workers.
"""
return
self
.
driver_worker
.
execute_method
(
"execute_model"
,
execute_model_req
)
def
_run_workers
(
def
_run_workers
(
self
,
self
,
method
:
str
,
method
:
str
,
*
args
,
*
args
,
driver_args
:
Optional
[
Tuple
[
Any
,
...]]
=
None
,
async_run_remote_workers_only
:
bool
=
False
,
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
all_args
:
Optional
[
List
[
Tuple
[
Any
,
...]]]
=
None
,
all_args
:
Optional
[
List
[
Tuple
[
Any
,
...]]]
=
None
,
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
use_dummy_driver
:
bool
=
False
,
use_dummy_driver
:
bool
=
False
,
...
@@ -193,9 +200,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -193,9 +200,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
"""Runs the given method on all workers. Can be used in the following
"""Runs the given method on all workers. Can be used in the following
ways:
ways:
- async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than blocking
on the results.
- args/kwargs: All workers share the same args/kwargs
- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
individually
"""
"""
...
@@ -204,11 +213,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -204,11 +213,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
raise
NotImplementedError
(
raise
NotImplementedError
(
"max_concurrent_workers is not supported yet."
)
"max_concurrent_workers is not supported yet."
)
if
driver_args
is
None
:
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
if
driver_kwargs
is
None
:
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
count
=
len
(
self
.
workers
)
count
=
len
(
self
.
workers
)
all_worker_args
=
repeat
(
args
,
count
)
if
all_args
is
None
\
all_worker_args
=
repeat
(
args
,
count
)
if
all_args
is
None
\
else
islice
(
all_args
,
1
,
None
)
else
islice
(
all_args
,
1
,
None
)
...
@@ -220,6 +224,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -220,6 +224,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# input. TODO(sang): Fix it.
# input. TODO(sang): Fix it.
assert
self
.
forward_dag
is
not
None
assert
self
.
forward_dag
is
not
None
output_channels
=
self
.
forward_dag
.
execute
(
1
)
output_channels
=
self
.
forward_dag
.
execute
(
1
)
ray_worker_outputs
=
[]
else
:
else
:
# Start the ray workers first.
# Start the ray workers first.
ray_worker_outputs
=
[
ray_worker_outputs
=
[
...
@@ -229,6 +234,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -229,6 +234,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
)
in
zip
(
self
.
workers
,
all_worker_args
,
all_worker_kwargs
)
)
in
zip
(
self
.
workers
,
all_worker_args
,
all_worker_kwargs
)
]
]
if
async_run_remote_workers_only
:
# Just return futures
return
ray_worker_outputs
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
# Start the driver worker after all the ray workers.
# Start the driver worker after all the ray workers.
if
not
use_dummy_driver
:
if
not
use_dummy_driver
:
driver_worker_output
=
self
.
driver_worker
.
execute_method
(
driver_worker_output
=
self
.
driver_worker
.
execute_method
(
...
@@ -255,6 +267,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -255,6 +267,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
return
[
driver_worker_output
]
+
ray_worker_outputs
return
[
driver_worker_output
]
+
ray_worker_outputs
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray
.
get
(
parallel_worker_tasks
)
def
_compiled_ray_dag
(
self
):
def
_compiled_ray_dag
(
self
):
import
pkg_resources
import
pkg_resources
required_version
=
"2.9"
required_version
=
"2.9"
...
@@ -264,7 +281,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -264,7 +281,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
f
"required, but found
{
current_version
}
"
)
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
InputNode
,
MultiOutputNode
from
ray.dag
import
InputNode
,
MultiOutputNode
assert
self
.
parallel_config
.
worker_use_
ray
assert
self
.
parallel_config
.
distributed_executor_backend
==
"
ray
"
# Right now, compiled DAG requires at least 1 arg. We send
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
# a dummy value for now. It will be fixed soon.
...
@@ -298,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
...
@@ -298,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
driver_exec
utor
=
make_async
(
self
.
driver_worker
.
execute_method
)
self
.
driver_exec
_method
=
make_async
(
self
.
driver_worker
.
execute_method
)
async
def
_
run_workers
_async
(
async
def
_
driver_execute_model
_async
(
self
,
self
,
method
:
str
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
*
args
,
)
->
List
[
SamplerOutput
]:
driver_args
:
Optional
[
Tuple
[
Any
,
...]]
=
None
,
return
await
self
.
driver_exec_method
(
"execute_model"
,
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
execute_model_req
)
**
kwargs
,
)
->
Any
:
async
def
_start_worker_execution_loop
(
self
):
"""Runs the given method on all workers."""
coros
=
[
coros
=
[]
worker
.
execute_method
.
remote
(
"start_worker_execution_loop"
)
for
worker
in
self
.
workers
if
driver_args
is
None
:
]
driver_args
=
args
return
await
asyncio
.
gather
(
*
coros
)
if
driver_kwargs
is
None
:
driver_kwargs
=
kwargs
coros
.
append
(
self
.
driver_executor
(
method
,
*
driver_args
,
**
driver_kwargs
))
# Run the ray workers asynchronously.
for
worker
in
self
.
workers
:
coros
.
append
(
worker
.
execute_method
.
remote
(
method
,
*
args
,
**
kwargs
))
all_outputs
=
await
asyncio
.
gather
(
*
coros
)
return
all_outputs
vllm/executor/ray_utils.py
View file @
b9e12416
...
@@ -44,7 +44,7 @@ try:
...
@@ -44,7 +44,7 @@ try:
except
ImportError
as
e
:
except
ImportError
as
e
:
logger
.
warning
(
logger
.
warning
(
"Failed to import Ray with %r. For
distributed
inference, "
"Failed to import Ray with %r. For
multi-node
inference, "
"please install Ray with `pip install ray`."
,
e
)
"please install Ray with `pip install ray`."
,
e
)
ray
=
None
# type: ignore
ray
=
None
# type: ignore
RayWorkerWrapper
=
None
# type: ignore
RayWorkerWrapper
=
None
# type: ignore
...
@@ -67,7 +67,7 @@ def initialize_ray_cluster(
...
@@ -67,7 +67,7 @@ def initialize_ray_cluster(
"""
"""
if
ray
is
None
:
if
ray
is
None
:
raise
ImportError
(
raise
ImportError
(
"Ray is not installed. Please install Ray to use
distributed
"
"Ray is not installed. Please install Ray to use
multi-node
"
"serving."
)
"serving."
)
# Connect to a ray cluster.
# Connect to a ray cluster.
...
...
vllm/inputs.py
0 → 100644
View file @
b9e12416
from
typing
import
(
TYPE_CHECKING
,
List
,
Literal
,
Optional
,
Sequence
,
TypedDict
,
Union
,
cast
,
overload
)
from
typing_extensions
import
NotRequired
if
TYPE_CHECKING
:
from
vllm.sequence
import
MultiModalData
class
ParsedText
(
TypedDict
):
content
:
str
is_tokens
:
Literal
[
False
]
class
ParsedTokens
(
TypedDict
):
content
:
List
[
int
]
is_tokens
:
Literal
[
True
]
# https://github.com/vllm-project/vllm/pull/4028
@
overload
def
parse_and_batch_prompt
(
prompt
:
Union
[
str
,
List
[
str
]])
->
Sequence
[
ParsedText
]:
...
@
overload
def
parse_and_batch_prompt
(
prompt
:
Union
[
List
[
int
],
List
[
List
[
int
]]])
->
Sequence
[
ParsedTokens
]:
...
def
parse_and_batch_prompt
(
prompt
:
Union
[
str
,
List
[
str
],
List
[
int
],
List
[
List
[
int
]]],
)
->
Union
[
Sequence
[
ParsedText
],
Sequence
[
ParsedTokens
]]:
if
isinstance
(
prompt
,
str
):
# case 1: a string
return
[
ParsedText
(
content
=
prompt
,
is_tokens
=
False
)]
if
isinstance
(
prompt
,
list
):
if
len
(
prompt
)
==
0
:
raise
ValueError
(
"please provide at least one prompt"
)
if
isinstance
(
prompt
[
0
],
str
):
# case 2: array of strings
return
[
ParsedText
(
content
=
elem
,
is_tokens
=
False
)
for
elem
in
cast
(
List
[
str
],
prompt
)
]
if
isinstance
(
prompt
[
0
],
int
):
# case 3: array of tokens
elem
=
cast
(
List
[
int
],
prompt
)
return
[
ParsedTokens
(
content
=
elem
,
is_tokens
=
True
)]
if
isinstance
(
prompt
[
0
],
list
):
if
len
(
prompt
[
0
])
==
0
:
raise
ValueError
(
"please provide at least one prompt"
)
if
isinstance
(
prompt
[
0
][
0
],
int
):
# case 4: array of token arrays
return
[
ParsedTokens
(
content
=
elem
,
is_tokens
=
True
)
for
elem
in
cast
(
List
[
List
[
int
]],
prompt
)
]
raise
ValueError
(
"prompt must be a string, array of strings, "
"array of tokens, or array of token arrays"
)
class
TextPrompt
(
TypedDict
):
"""Schema for a text prompt."""
prompt
:
str
"""The input text to be tokenized before passing to the model."""
multi_modal_data
:
NotRequired
[
"MultiModalData"
]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class
TokensPrompt
(
TypedDict
):
"""Schema for a tokenized prompt."""
prompt_token_ids
:
List
[
int
]
"""A list of token IDs to pass to the model."""
multi_modal_data
:
NotRequired
[
"MultiModalData"
]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class
TextTokensPrompt
(
TypedDict
):
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt
:
str
"""The prompt text."""
prompt_token_ids
:
List
[
int
]
"""The token IDs of the prompt. If None, we use the
tokenizer to convert the prompts to token IDs."""
multi_modal_data
:
NotRequired
[
"MultiModalData"
]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""
The inputs to the LLM, which can take one of the following forms:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""
PromptInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
,
TextTokensPrompt
]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class
LLMInputs
(
TypedDict
):
prompt_token_ids
:
List
[
int
]
prompt
:
NotRequired
[
Optional
[
str
]]
multi_modal_data
:
NotRequired
[
Optional
[
"MultiModalData"
]]
vllm/logger.py
View file @
b9e12416
...
@@ -14,6 +14,7 @@ import vllm.envs as envs
...
@@ -14,6 +14,7 @@ import vllm.envs as envs
VLLM_CONFIGURE_LOGGING
=
envs
.
VLLM_CONFIGURE_LOGGING
VLLM_CONFIGURE_LOGGING
=
envs
.
VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH
=
envs
.
VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_CONFIG_PATH
=
envs
.
VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_LEVEL
=
envs
.
VLLM_LOGGING_LEVEL
_FORMAT
=
"%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_FORMAT
=
"%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT
=
"%m-%d %H:%M:%S"
_DATE_FORMAT
=
"%m-%d %H:%M:%S"
...
@@ -30,7 +31,7 @@ DEFAULT_LOGGING_CONFIG = {
...
@@ -30,7 +31,7 @@ DEFAULT_LOGGING_CONFIG = {
"vllm"
:
{
"vllm"
:
{
"class"
:
"logging.StreamHandler"
,
"class"
:
"logging.StreamHandler"
,
"formatter"
:
"vllm"
,
"formatter"
:
"vllm"
,
"level"
:
"INFO"
,
"level"
:
VLLM_LOGGING_LEVEL
,
"stream"
:
"ext://sys.stdout"
,
"stream"
:
"ext://sys.stdout"
,
},
},
},
},
...
...
vllm/lora/fully_sharded_layers.py
View file @
b9e12416
# pylint: disable=unused-argument
# pylint: disable=unused-argument
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
...
@@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
lora_a
=
lora_a
[:,
start_idx
:
start_idx
+
shard_size
]
lora_a
=
lora_a
[:,
start_idx
:
start_idx
+
shard_size
]
return
lora_a
return
lora_a
def
apply_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
self
.
base_layer
,
x
,
bias
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
,
out_orig_shape
=
output
.
view
(
-
1
,
...
@@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
...
@@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
)
)
def
_mcp_apply
_weights
(
x
,
bias
,
layer
):
def
_mcp_apply
(
x
,
bias
,
layer
):
"""
"""
MergedColumnParallelLinearWithShardedLoRA and
MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same
QKVParallelLinearWithShardedLora share the same
...
@@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer):
...
@@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer):
"""
"""
# expecting 2 for column parallel and 3 for qkv
# expecting 2 for column parallel and 3 for qkv
n
=
len
(
layer
.
lora_a_stacked
)
n
=
len
(
layer
.
lora_a_stacked
)
output
=
layer
.
base_layer
.
linear_method
.
apply_weights
(
output
=
layer
.
base_layer
.
quant_method
.
apply
(
layer
.
base_layer
,
x
,
bias
)
layer
.
base_layer
,
x
,
bias
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
]),
output
.
shape
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
.
shape
[
-
1
]),
output
.
shape
...
@@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA(
...
@@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA(
Based on S-LoRA, slicing happens along the rank dim.
Based on S-LoRA, slicing happens along the rank dim.
"""
"""
def
slice_lora_a
(
self
,
lora_a
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
if
lora_a
[
0
]
is
None
or
lora_a
[
1
]
is
None
:
return
lora_a
output_shard_size
=
self
.
lora_a_stacked
[
0
].
shape
[
2
]
output_shard_size
=
self
.
lora_a_stacked
[
0
].
shape
[
2
]
output_start_idx
=
self
.
tp_rank
*
output_shard_size
output_start_idx
=
self
.
tp_rank
*
output_shard_size
lora_a
=
[
lora_a
=
[
lora_a
[
i
][:,
output_start_idx
:
output_start_idx
+
output_shard_size
]
lora_a
[
0
][:,
for
i
in
range
(
2
)
output_start_idx
:
output_start_idx
+
output_shard_size
],
lora_a
[
1
][:,
output_start_idx
:
output_start_idx
+
output_shard_size
]
]
]
return
lora_a
return
lora_a
def
apply
_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
_mcp_apply
_weights
(
x
,
bias
,
self
)
return
_mcp_apply
(
x
,
bias
,
self
)
@
classmethod
@
classmethod
@
_fully_sharded_can_replace
@
_fully_sharded_can_replace
...
@@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
...
@@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
Based on S-LoRA, slicing happens along the rank dim.
Based on S-LoRA, slicing happens along the rank dim.
"""
"""
def
slice_lora_a
(
self
,
lora_a
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
if
lora_a
[
0
]
is
None
or
lora_a
[
1
]
is
None
or
lora_a
[
2
]
is
None
:
return
lora_a
shard_size
=
[
self
.
lora_a_stacked
[
i
].
shape
[
2
]
for
i
in
range
(
3
)]
shard_size
=
[
self
.
lora_a_stacked
[
i
].
shape
[
2
]
for
i
in
range
(
3
)]
start_idx
=
[
self
.
tp_rank
*
shard_size
[
i
]
for
i
in
range
(
3
)]
start_idx
=
[
self
.
tp_rank
*
shard_size
[
i
]
for
i
in
range
(
3
)]
lora_a
=
[
lora_a
=
[
lora_a
[
i
][:,
start_idx
[
i
]:
start_idx
[
i
]
+
lora_a
[
0
][:,
start_idx
[
0
]:
start_idx
[
0
]
+
shard_size
[
0
]],
shard_size
[
i
]]
if
lora_a
[
i
]
is
not
None
else
None
lora_a
[
1
][:,
start_idx
[
1
]:
start_idx
[
1
]
+
shard_size
[
1
]]
,
f
or
i
in
range
(
3
)
l
or
a_a
[
2
][:,
start_idx
[
2
]:
start_idx
[
2
]
+
shard_size
[
2
]]
]
]
return
lora_a
return
lora_a
def
apply
_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
_mcp_apply
_weights
(
x
,
bias
,
self
)
return
_mcp_apply
(
x
,
bias
,
self
)
@
classmethod
@
classmethod
@
_fully_sharded_can_replace
@
_fully_sharded_can_replace
...
@@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
...
@@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
lora_b
=
lora_b
[:,
start_idx
:
end_idx
]
lora_b
=
lora_b
[:,
start_idx
:
end_idx
]
return
lora_b
return
lora_b
def
apply_weights
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
)
self
.
base_layer
,
x
)
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output
,
out_orig_shape
=
output
.
view
(
-
1
,
output
,
out_orig_shape
=
output
.
view
(
-
1
,
...
...
vllm/lora/layers.py
View file @
b9e12416
# pylint: disable=unused-argument
# pylint: disable=unused-argument
import
math
import
math
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -22,6 +22,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -22,6 +22,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
(
LinearScalingRotaryEmbedding
,
RotaryEmbedding
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -145,11 +147,15 @@ class LoRAMapping:
...
@@ -145,11 +147,15 @@ class LoRAMapping:
class
BaseLayerWithLoRA
(
nn
.
Module
):
class
BaseLayerWithLoRA
(
nn
.
Module
):
def
slice_lora_a
(
self
,
lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
slice_lora_a
(
self
,
lora_a
:
Union
[
torch
.
Tensor
,
List
[
Union
[
torch
.
Tensor
,
None
]]]
)
->
Union
[
torch
.
Tensor
,
List
[
Union
[
torch
.
Tensor
,
None
]]]:
"""Slice lora a if splitting for tensor parallelism."""
"""Slice lora a if splitting for tensor parallelism."""
...
...
def
slice_lora_b
(
self
,
lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
slice_lora_b
(
self
,
lora_b
:
Union
[
torch
.
Tensor
,
List
[
Union
[
torch
.
Tensor
,
None
]]]
)
->
Union
[
torch
.
Tensor
,
List
[
Union
[
torch
.
Tensor
,
None
]]]:
"""Slice lora b if splitting with tensor parallelism."""
"""Slice lora b if splitting with tensor parallelism."""
...
...
...
@@ -181,6 +187,7 @@ class BaseLayerWithLoRA(nn.Module):
...
@@ -181,6 +187,7 @@ class BaseLayerWithLoRA(nn.Module):
sampler_indices
:
torch
.
Tensor
,
sampler_indices
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
long_lora_indices
:
torch
.
Tensor
,
indices_len
:
List
[
int
],
indices_len
:
List
[
int
],
):
):
"""Sets the mapping indices."""
"""Sets the mapping indices."""
...
@@ -302,6 +309,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -302,6 +309,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
sampler_indices
:
torch
.
Tensor
,
sampler_indices
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
long_lora_indices
:
torch
.
Tensor
,
indices_len
:
List
[
int
],
indices_len
:
List
[
int
],
):
):
self
.
indices
=
base_indices
self
.
indices
=
base_indices
...
@@ -427,6 +435,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -427,6 +435,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
sampler_indices
:
torch
.
Tensor
,
sampler_indices
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
long_lora_indices
:
torch
.
Tensor
,
indices_len
:
List
[
int
],
indices_len
:
List
[
int
],
):
):
self
.
indices
=
base_indices
self
.
indices
=
base_indices
...
@@ -539,10 +548,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -539,10 +548,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
1
][
index
]
=
0
self
.
lora_b_stacked
[
1
][
index
]
=
0
def
slice_lora_a
(
self
,
lora_a
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
return
lora_a
return
lora_a
def
slice_lora_b
(
self
,
lora_b
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_b
(
self
,
lora_b
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
if
lora_b
[
0
]
is
None
or
lora_b
[
1
]
is
None
:
return
lora_b
shard_size
=
self
.
output_dim
shard_size
=
self
.
output_dim
start_idx
=
self
.
tp_rank
*
shard_size
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
...
@@ -767,10 +782,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -767,10 +782,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
lora_a_stacked
[
2
][
index
]
=
0
self
.
lora_a_stacked
[
2
][
index
]
=
0
self
.
lora_b_stacked
[
2
][
index
]
=
0
self
.
lora_b_stacked
[
2
][
index
]
=
0
def
slice_lora_a
(
self
,
lora_a
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_a
(
self
,
lora_a
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
return
lora_a
return
lora_a
def
slice_lora_b
(
self
,
lora_b
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
def
slice_lora_b
(
self
,
lora_b
:
List
[
Union
[
torch
.
Tensor
,
None
]]
)
->
List
[
Union
[
torch
.
Tensor
,
None
]]:
lora_b_q
,
lora_b_k
,
lora_b_v
=
None
,
None
,
None
if
lora_b
[
0
]
is
not
None
:
if
lora_b
[
0
]
is
not
None
:
lora_b_q
=
lora_b
[
0
][:,
self
.
q_proj_shard_size
*
lora_b_q
=
lora_b
[
0
][:,
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
...
@@ -936,6 +956,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -936,6 +956,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
sampler_indices
:
torch
.
Tensor
,
sampler_indices
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
long_lora_indices
:
torch
.
Tensor
,
indices_len
:
List
[
int
],
indices_len
:
List
[
int
],
):
):
self
.
indices
=
base_indices
self
.
indices
=
base_indices
...
@@ -992,7 +1013,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -992,7 +1013,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@
property
@
property
def
weight
(
self
):
def
weight
(
self
):
return
self
.
base_layer
.
weight
if
hasattr
(
return
self
.
base_layer
.
weight
if
hasattr
(
self
.
base_layer
,
"weight"
)
else
self
.
base_layer
.
qweight
self
.
base_layer
,
"weight"
)
else
self
.
base_layer
.
qweight
...
@@ -1113,6 +1133,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -1113,6 +1133,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
sampler_indices
:
torch
.
Tensor
,
sampler_indices
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
long_lora_indices
:
torch
.
Tensor
,
indices_len
:
List
[
int
],
indices_len
:
List
[
int
],
):
):
self
.
indices
=
sampler_indices
self
.
indices
=
sampler_indices
...
@@ -1179,3 +1200,101 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -1179,3 +1200,101 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
# Special handling for the LogitsProcessor.
# Special handling for the LogitsProcessor.
return
False
return
False
class
LinearScalingRotaryEmbeddingWithLora
(
BaseLayerWithLoRA
):
"""Implements RoPE-scaled embeddings with linear scaling for
multiple LoRA adapters with a specialized kernel.
Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
which can handle multi lora adapters in a specialied kernel.
"""
def
__init__
(
self
,
base_layer
:
RotaryEmbedding
)
->
None
:
super
().
__init__
()
self
.
base_layer
=
base_layer
# Lazily initialized
self
.
long_lora_indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
@
property
def
scaling_factors
(
self
):
return
self
.
base_layer
.
scaling_factors
@
property
def
rotary_dim
(
self
):
return
self
.
base_layer
.
rotary_dim
def
create_lora_weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
None
:
scaling_factors
=
list
(
lora_config
.
long_lora_scaling_factors
)
if
lora_config
.
long_lora_scaling_factors
else
[]
base_scaling_factor
=
(
self
.
base_layer
.
scaling_factor
if
isinstance
(
self
.
base_layer
,
LinearScalingRotaryEmbedding
)
else
1.0
)
scaling_factors
=
sorted
(
list
(
set
([
base_scaling_factor
]
+
scaling_factors
)))
self
.
base_layer
=
LinearScalingRotaryEmbedding
(
self
.
base_layer
.
head_size
,
self
.
base_layer
.
rotary_dim
,
self
.
base_layer
.
max_position_embeddings
,
self
.
base_layer
.
base
,
self
.
base_layer
.
is_neox_style
,
scaling_factors
,
self
.
base_layer
.
dtype
,
)
def
reset_lora
(
self
,
index
:
int
):
...
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
):
...
def
set_mapping
(
self
,
base_indices
:
torch
.
Tensor
,
sampler_indices
:
torch
.
Tensor
,
sampler_indices_padded
:
torch
.
Tensor
,
embeddings_indices
:
torch
.
Tensor
,
long_lora_indices
:
torch
.
Tensor
,
indices_len
:
List
[
int
],
):
self
.
long_lora_indices
=
long_lora_indices
self
.
indices_len
=
indices_len
def
forward
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
base_layer
(
positions
,
query
,
key
,
offsets
=
self
.
long_lora_indices
[:
self
.
indices_len
[
4
]])
@
property
def
scaling_factor_to_offset
(
self
)
->
Dict
[
float
,
int
]:
return
self
.
base_layer
.
scaling_factor_to_offset
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
"""Returns True if the layer can be replaced by this LoRA layer."""
return
type
(
source_layer
)
is
LinearScalingRotaryEmbedding
or
type
(
source_layer
)
is
RotaryEmbedding
def
extra_repr
(
self
)
->
str
:
return
self
.
base_layer
.
extra_repr
()
vllm/lora/models.py
View file @
b9e12416
...
@@ -3,7 +3,8 @@ import json
...
@@ -3,7 +3,8 @@ import json
import
math
import
math
import
os
import
os
import
re
import
re
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
dataclasses
import
dataclass
,
field
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
safetensors.torch
import
safetensors.torch
import
torch
import
torch
...
@@ -11,7 +12,9 @@ from torch import nn
...
@@ -11,7 +12,9 @@ from torch import nn
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
BaseLayerWithLoRA
,
LoRAMapping
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
LinearScalingRotaryEmbeddingWithLora
,
LoRAMapping
)
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
parse_fine_tuned_lora_name
,
replace_submodule
)
parse_fine_tuned_lora_name
,
replace_submodule
)
...
@@ -22,10 +25,27 @@ logger = init_logger(__name__)
...
@@ -22,10 +25,27 @@ logger = init_logger(__name__)
_GLOBAL_LORA_ID
=
0
_GLOBAL_LORA_ID
=
0
@
dataclass
class
LongContextLoRAContext
:
"""Context for lora adapters that support long context."""
# The scaling factors to support long context lora fine tuned models.
scaling_factors
:
List
[
float
]
# dimension to apply rotary embedding.
rot_dim
:
int
# offsets to the sin_cos_cache for each lora_id loaded.
# This value is dynamically modified.
offsets_by_lora_id
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dict
)
def
convert_mapping
(
def
convert_mapping
(
mapping
:
LoRAMapping
,
lora_index_to_id
:
List
[
Optional
[
int
]],
mapping
:
LoRAMapping
,
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
lora_index_to_id
:
List
[
Optional
[
int
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
List
[
int
]]:
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
LongContextLoRAContext
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
List
[
int
]]:
"""Converts LoRAMapping to index tensors.
"""Converts LoRAMapping to index tensors.
Args:
Args:
...
@@ -34,6 +54,7 @@ def convert_mapping(
...
@@ -34,6 +54,7 @@ def convert_mapping(
max_loras: Maximum number of LoRAs.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
Returns:
A tuple of tensors:
A tuple of tensors:
...
@@ -51,11 +72,23 @@ def convert_mapping(
...
@@ -51,11 +72,23 @@ def convert_mapping(
requests to embedding indices. First row is for embeddings
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors.
indices_len: List of lengths of the above tensors.
Used to index into each tensor. It contains length for
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices). If long_lora doesn't
exist, it only contains first 4 entries.
"""
"""
index_mapping_indices
:
List
[
int
]
=
list
(
mapping
.
index_mapping
).
copy
()
index_mapping_indices
:
List
[
int
]
=
list
(
mapping
.
index_mapping
).
copy
()
embedding_indices
=
index_mapping_indices
.
copy
()
embedding_indices
=
index_mapping_indices
.
copy
()
lora_indices
=
index_mapping_indices
.
copy
()
lora_indices
=
index_mapping_indices
.
copy
()
long_lora_offsets
:
Optional
[
torch
.
Tensor
]
=
None
if
long_lora_context
:
long_lora_offsets
=
torch
.
zeros
(
len
(
index_mapping_indices
),
device
=
"cuda"
,
dtype
=
torch
.
long
)
prompt_mapping
:
List
[
int
]
=
[
prompt_mapping
:
List
[
int
]
=
[
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
for
x
in
mapping
.
prompt_mapping
for
x
in
mapping
.
prompt_mapping
...
@@ -66,13 +99,20 @@ def convert_mapping(
...
@@ -66,13 +99,20 @@ def convert_mapping(
lora_idx
=
(
lora_index_to_id
.
index
(
index_mapping_indices
[
i
])
lora_idx
=
(
lora_index_to_id
.
index
(
index_mapping_indices
[
i
])
if
index_mapping_indices
[
i
]
>
0
else
-
1
)
if
index_mapping_indices
[
i
]
>
0
else
-
1
)
embedding_indices
[
i
]
=
lora_idx
if
index_mapping_indices
[
i
]
>
0
else
0
embedding_indices
[
i
]
=
lora_idx
if
index_mapping_indices
[
i
]
>
0
else
0
index_mapping_indices
[
i
]
=
i
lora_indices
[
i
]
=
lora_idx
lora_indices
[
i
]
=
lora_idx
if
long_lora_context
:
indices
=
torch
.
tensor
(
assert
long_lora_offsets
is
not
None
[
index_mapping_indices
,
lora_indices
,
embedding_indices
],
lora_offset
:
int
=
long_lora_context
.
offsets_by_lora_id
.
get
(
dtype
=
torch
.
long
,
index_mapping_indices
[
i
],
0
)
device
=
"cuda"
)
long_lora_offsets
[
i
]
=
lora_offset
indices_list
:
List
[
Union
[
List
[
int
],
torch
.
Tensor
]]
=
[
index_mapping_indices
,
lora_indices
,
embedding_indices
]
if
long_lora_context
:
assert
long_lora_offsets
is
not
None
indices_list
.
append
(
long_lora_offsets
)
indices
=
torch
.
tensor
(
indices_list
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
prompt_mapping_tensor
=
torch
.
tensor
(
prompt_mapping
,
prompt_mapping_tensor
=
torch
.
tensor
(
prompt_mapping
,
device
=
"cuda"
,
device
=
"cuda"
,
dtype
=
torch
.
long
)
dtype
=
torch
.
long
)
...
@@ -89,13 +129,21 @@ def convert_mapping(
...
@@ -89,13 +129,21 @@ def convert_mapping(
torch
.
arange
(
torch
.
arange
(
0
,
len
(
sampler_indices_padded
),
device
=
"cuda"
,
dtype
=
torch
.
long
)
+
0
,
len
(
sampler_indices_padded
),
device
=
"cuda"
,
dtype
=
torch
.
long
)
+
(
sampler_indices_padded
*
len
(
sampler_indices_padded
)))
(
sampler_indices_padded
*
len
(
sampler_indices_padded
)))
long_lora_indices
=
None
long_lora_indices_len
:
Optional
[
int
]
=
None
if
long_lora_context
:
long_lora_indices
=
indices
[
3
]
long_lora_indices_len
=
long_lora_indices
.
shape
[
-
1
]
# Contain length of indices tensors. Used to index into each tensor.
indices_len
=
[
indices_len
=
[
base_indices
.
shape
[
-
1
],
sampler_indices
.
shape
[
-
1
],
base_indices
.
shape
[
-
1
],
sampler_indices
.
shape
[
-
1
],
sampler_indices_padded
.
shape
[
-
1
],
embeddings_indices
.
shape
[
-
1
]
sampler_indices_padded
.
shape
[
-
1
],
embeddings_indices
.
shape
[
-
1
]
]
]
if
long_lora_indices_len
is
not
None
:
indices_len
.
append
(
long_lora_indices_len
)
return
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
return
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
indices_len
)
embeddings_indices
,
long_lora_indices
,
indices_len
)
def
get_lora_id
():
def
get_lora_id
():
...
@@ -112,13 +160,35 @@ class LoRAModel:
...
@@ -112,13 +160,35 @@ class LoRAModel:
lora_model_id
:
int
,
lora_model_id
:
int
,
rank
:
int
,
rank
:
int
,
loras
:
Dict
[
str
,
LoRALayerWeights
],
loras
:
Dict
[
str
,
LoRALayerWeights
],
scaling_factor
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
"""
Args:
lora_model_id: The integer id for the lora model.
rank: lora rank.
loras: module name -> weights for lora-replaced layers.
scaling_factor: Scaling factor to support long context lora model.
None if the lora is not tuned for long context support.
"""
self
.
id
=
lora_model_id
self
.
id
=
lora_model_id
# Scaling factor for long context lora model. None if it is not
# fine tuned for the long context.
self
.
scaling_factor
=
scaling_factor
assert
(
lora_model_id
>
assert
(
lora_model_id
>
0
),
f
"a valid lora id should be greater than 0, got
{
self
.
id
}
"
0
),
f
"a valid lora id should be greater than 0, got
{
self
.
id
}
"
self
.
rank
=
rank
self
.
rank
=
rank
self
.
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
loras
self
.
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
loras
def
clone
(
self
,
lora_model_id
:
int
)
->
"LoRAModel"
:
"""Return a copy of the object with different ids.
Will share the underlying tensors."""
return
self
.
__class__
(
lora_model_id
,
rank
=
self
.
rank
,
loras
=
self
.
loras
.
copy
(),
)
@
property
@
property
def
extra_vocab_size
(
self
)
->
int
:
def
extra_vocab_size
(
self
)
->
int
:
return
max
(
lora
.
extra_vocab_size
return
max
(
lora
.
extra_vocab_size
...
@@ -140,6 +210,7 @@ class LoRAModel:
...
@@ -140,6 +210,7 @@ class LoRAModel:
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
embeddings
:
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
=
None
,
embeddings
:
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
=
None
,
target_embedding_padding
:
Optional
[
int
]
=
None
,
target_embedding_padding
:
Optional
[
int
]
=
None
,
scaling_factor
:
Optional
[
float
]
=
None
,
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
List
[
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
List
[
str
]]
=
None
,
)
->
"LoRAModel"
:
)
->
"LoRAModel"
:
...
@@ -189,13 +260,15 @@ class LoRAModel:
...
@@ -189,13 +260,15 @@ class LoRAModel:
for
lora
in
loras
.
values
():
for
lora
in
loras
.
values
():
lora
.
optimize
()
lora
.
optimize
()
return
cls
(
lora_model_id
,
rank
,
loras
)
return
cls
(
lora_model_id
,
rank
,
loras
,
scaling_factor
=
scaling_factor
)
@
classmethod
@
classmethod
def
from_local_checkpoint
(
def
from_local_checkpoint
(
cls
,
cls
,
lora_dir
:
str
,
lora_dir
:
str
,
expected_lora_modules
:
List
[
str
],
expected_lora_modules
:
List
[
str
],
*
,
max_position_embeddings
:
Optional
[
int
]
=
None
,
lora_model_id
:
Optional
[
int
]
=
None
,
lora_model_id
:
Optional
[
int
]
=
None
,
device
:
str
=
"cuda"
,
device
:
str
=
"cuda"
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
...
@@ -203,7 +276,23 @@ class LoRAModel:
...
@@ -203,7 +276,23 @@ class LoRAModel:
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
List
[
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
List
[
str
]]
=
None
,
)
->
"LoRAModel"
:
)
->
"LoRAModel"
:
"""Create a LoRAModel from a local checkpoint."""
"""Create a LoRAModel from a local checkpoint.
Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
max_position_embeddings: Max position embedding length. Used to
scaling the largest context length. If None, the lora model's
context length is not scaled.
lora_model_id: Lora model id. If not given, automatically set by
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.
Returns:
Loaded LoRA Model.
"""
lora_config_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_config.json"
)
lora_config_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_config.json"
)
lora_tensor_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.safetensors"
)
lora_tensor_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.safetensors"
)
lora_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.bin"
)
lora_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.bin"
)
...
@@ -221,7 +310,9 @@ class LoRAModel:
...
@@ -221,7 +310,9 @@ class LoRAModel:
if
part_name
not
in
expected_lora_modules
:
if
part_name
not
in
expected_lora_modules
:
unexpected_modules
.
append
(
module
)
unexpected_modules
.
append
(
module
)
# loaded lora's target modules must be a subset of expected_lora_modules
# loaded lora's target modules must be a subset of expected_lora_modules
if
unexpected_modules
:
if
unexpected_modules
:
print
(
unexpected_modules
,
"modules"
)
raise
ValueError
(
raise
ValueError
(
f
"While loading
{
lora_dir
}
, expected"
f
"While loading
{
lora_dir
}
, expected"
f
" target modules in
{
expected_lora_modules
}
"
f
" target modules in
{
expected_lora_modules
}
"
...
@@ -243,6 +334,14 @@ class LoRAModel:
...
@@ -243,6 +334,14 @@ class LoRAModel:
rank
=
config
[
"r"
]
rank
=
config
[
"r"
]
lora_alpha
=
config
[
"lora_alpha"
]
lora_alpha
=
config
[
"lora_alpha"
]
context_length
=
config
.
get
(
"context_length"
,
None
)
scaling_factor
=
None
if
context_length
:
if
max_position_embeddings
is
None
:
max_position_embeddings
=
context_length
scaling_factor
=
float
(
math
.
ceil
(
context_length
/
max_position_embeddings
))
return
cls
.
from_lora_tensors
(
return
cls
.
from_lora_tensors
(
lora_model_id
=
get_lora_id
()
lora_model_id
=
get_lora_id
()
if
lora_model_id
is
None
else
lora_model_id
,
if
lora_model_id
is
None
else
lora_model_id
,
...
@@ -253,6 +352,7 @@ class LoRAModel:
...
@@ -253,6 +352,7 @@ class LoRAModel:
dtype
=
dtype
,
dtype
=
dtype
,
embeddings
=
embeddings
,
embeddings
=
embeddings
,
target_embedding_padding
=
target_embedding_padding
,
target_embedding_padding
=
target_embedding_padding
,
scaling_factor
=
scaling_factor
,
embedding_modules
=
embedding_modules
,
embedding_modules
=
embedding_modules
,
embedding_padding_modules
=
embedding_padding_modules
,
embedding_padding_modules
=
embedding_padding_modules
,
)
)
...
@@ -286,6 +386,7 @@ class LoRAModelManager:
...
@@ -286,6 +386,7 @@ class LoRAModelManager:
self
.
max_num_batched_tokens
=
math
.
ceil
(
max_num_batched_tokens
/
8
)
*
8
self
.
max_num_batched_tokens
=
math
.
ceil
(
max_num_batched_tokens
/
8
)
*
8
self
.
lora_index_to_id
:
List
[
Optional
[
int
]]
=
[
None
]
*
self
.
lora_slots
self
.
lora_index_to_id
:
List
[
Optional
[
int
]]
=
[
None
]
*
self
.
lora_slots
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
long_lora_context
:
Optional
[
LongContextLoRAContext
]
=
None
self
.
base_indices
=
torch
.
empty
(
self
.
max_num_batched_tokens
,
self
.
base_indices
=
torch
.
empty
(
self
.
max_num_batched_tokens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
device
=
"cuda"
)
...
@@ -299,6 +400,12 @@ class LoRAModelManager:
...
@@ -299,6 +400,12 @@ class LoRAModelManager:
self
.
max_num_batched_tokens
,
self
.
max_num_batched_tokens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
device
=
"cuda"
)
self
.
long_lora_indices
=
torch
.
empty
(
self
.
max_num_batched_tokens
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self
.
scaling_factor_to_offset
:
Dict
[
float
,
int
]
=
{}
# 4 is the number of indicies tensors defined above
# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
# embeddings_indices
...
@@ -308,6 +415,10 @@ class LoRAModelManager:
...
@@ -308,6 +415,10 @@ class LoRAModelManager:
if
hasattr
(
self
.
model
,
"supported_lora_modules"
):
if
hasattr
(
self
.
model
,
"supported_lora_modules"
):
self
.
supported_lora_modules
=
copy
.
deepcopy
(
self
.
supported_lora_modules
=
copy
.
deepcopy
(
self
.
model
.
supported_lora_modules
)
self
.
model
.
supported_lora_modules
)
if
lora_config
.
long_lora_scaling_factors
:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self
.
supported_lora_modules
.
append
(
"rotary_emb"
)
self
.
packed_modules_mapping
=
copy
.
deepcopy
(
self
.
packed_modules_mapping
=
copy
.
deepcopy
(
self
.
model
.
packed_modules_mapping
)
self
.
model
.
packed_modules_mapping
)
self
.
packed_modules
:
Dict
[
str
,
List
[
str
]]
=
{}
self
.
packed_modules
:
Dict
[
str
,
List
[
str
]]
=
{}
...
@@ -373,12 +484,32 @@ class LoRAModelManager:
...
@@ -373,12 +484,32 @@ class LoRAModelManager:
return
True
return
True
return
False
return
False
def
_set_long_lora_context
(
self
,
lora
:
LoRAModel
):
if
self
.
long_lora_context
is
None
:
return
if
lora
.
scaling_factor
is
None
:
return
if
(
lora
.
scaling_factor
not
in
self
.
scaling_factor_to_offset
):
raise
ValueError
(
f
"Long LoRA scaling factor
{
lora
.
scaling_factor
}
"
" has not been initialized."
)
offsets
=
self
.
scaling_factor_to_offset
.
get
(
lora
.
scaling_factor
)
if
offsets
:
self
.
long_lora_context
.
offsets_by_lora_id
[
lora
.
id
]
=
offsets
def
_add_lora
(
self
,
lora
:
LoRAModel
):
def
_add_lora
(
self
,
lora
:
LoRAModel
):
self
.
_create_merged_loras_inplace
(
lora
)
self
.
_create_merged_loras_inplace
(
lora
)
self
.
_registered_loras
[
lora
.
id
]
=
lora
self
.
_registered_loras
[
lora
.
id
]
=
lora
self
.
_set_long_lora_context
(
lora
)
def
add_lora
(
self
,
lora
:
LoRAModel
)
->
bool
:
def
add_lora
(
self
,
lora
:
LoRAModel
)
->
bool
:
"""Add a LoRAModel to the manager CPU cache."""
"""Add a LoRAModel to the manager CPU cache."""
logger
.
debug
(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s"
,
lora
.
id
,
lora
.
id
,
lora
.
scaling_factor
)
if
lora
.
id
not
in
self
.
_registered_loras
:
if
lora
.
id
not
in
self
.
_registered_loras
:
if
len
(
self
.
_registered_loras
)
>=
self
.
capacity
:
if
len
(
self
.
_registered_loras
)
>=
self
.
capacity
:
raise
RuntimeError
(
"No free LoRA slots."
)
raise
RuntimeError
(
"No free LoRA slots."
)
...
@@ -390,15 +521,18 @@ class LoRAModelManager:
...
@@ -390,15 +521,18 @@ class LoRAModelManager:
"""Remove a LoRAModel from the manager CPU cache."""
"""Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora?
# TODO: should we check active lora?
self
.
deactivate_lora
(
lora_id
)
self
.
deactivate_lora
(
lora_id
)
if
self
.
long_lora_context
:
self
.
long_lora_context
.
offsets_by_lora_id
.
pop
(
lora_id
,
None
)
return
bool
(
self
.
_registered_loras
.
pop
(
lora_id
,
None
))
return
bool
(
self
.
_registered_loras
.
pop
(
lora_id
,
None
))
# TODO see if this can be vectorized
# TODO see if this can be vectorized
def
_set_lora_mapping
(
self
,
mapping
:
LoRAMapping
)
->
None
:
def
_set_lora_mapping
(
self
,
mapping
:
LoRAMapping
)
->
None
:
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
embeddings_indices
,
long_lora_offsets_tensor
,
indices_len
)
=
convert_mapping
(
mapping
,
self
.
lora_index_to_id
,
indices_len
)
=
convert_mapping
(
mapping
,
self
.
lora_index_to_id
,
self
.
lora_slots
+
1
,
self
.
vocab_size
,
self
.
lora_slots
+
1
,
self
.
vocab_size
,
self
.
lora_config
.
lora_extra_vocab_size
)
self
.
lora_config
.
lora_extra_vocab_size
,
self
.
long_lora_context
)
self
.
base_indices
[:
base_indices
.
shape
[
0
]].
copy_
(
base_indices
)
self
.
base_indices
[:
base_indices
.
shape
[
0
]].
copy_
(
base_indices
)
self
.
sampler_indices
[:
sampler_indices
.
shape
[
0
]].
copy_
(
sampler_indices
)
self
.
sampler_indices
[:
sampler_indices
.
shape
[
0
]].
copy_
(
sampler_indices
)
self
.
sampler_indices_padded
[:
sampler_indices_padded
.
shape
[
0
]].
copy_
(
self
.
sampler_indices_padded
[:
sampler_indices_padded
.
shape
[
0
]].
copy_
(
...
@@ -406,6 +540,11 @@ class LoRAModelManager:
...
@@ -406,6 +540,11 @@ class LoRAModelManager:
self
.
embeddings_indices
[:
embeddings_indices
.
self
.
embeddings_indices
[:
embeddings_indices
.
shape
[
0
],
:
embeddings_indices
.
shape
[
1
]].
copy_
(
shape
[
0
],
:
embeddings_indices
.
shape
[
1
]].
copy_
(
embeddings_indices
)
embeddings_indices
)
if
long_lora_offsets_tensor
is
not
None
:
self
.
long_lora_indices
[:
long_lora_offsets_tensor
.
shape
[
0
]].
copy_
(
long_lora_offsets_tensor
)
else
:
self
.
long_lora_indices
.
zero_
()
# Maintain the reference
# Maintain the reference
self
.
indices_len
[:]
=
indices_len
self
.
indices_len
[:]
=
indices_len
...
@@ -428,7 +567,8 @@ class LoRAModelManager:
...
@@ -428,7 +567,8 @@ class LoRAModelManager:
self
.
_active_loras
.
clear
()
self
.
_active_loras
.
clear
()
def
_create_lora_modules
(
self
):
def
_create_lora_modules
(
self
):
for
module_name
,
module
in
self
.
model
.
named_modules
():
for
module_name
,
module
in
self
.
model
.
named_modules
(
remove_duplicate
=
False
):
if
not
self
.
_match_target_modules
(
module_name
):
if
not
self
.
_match_target_modules
(
module_name
):
continue
continue
parts
=
module_name
.
split
(
"."
)[
-
1
]
parts
=
module_name
.
split
(
"."
)[
-
1
]
...
@@ -437,6 +577,13 @@ class LoRAModelManager:
...
@@ -437,6 +577,13 @@ class LoRAModelManager:
self
.
model
,
module_name
,
self
.
model
,
module_name
,
from_layer
(
module
,
self
.
lora_slots
,
self
.
lora_config
,
from_layer
(
module
,
self
.
lora_slots
,
self
.
lora_config
,
packed_moduled_lst
,
self
.
model
.
config
))
packed_moduled_lst
,
self
.
model
.
config
))
# LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata.
if
isinstance
(
new_module
,
LinearScalingRotaryEmbeddingWithLora
):
self
.
long_lora_context
=
LongContextLoRAContext
(
new_module
.
scaling_factors
,
new_module
.
rotary_dim
)
self
.
scaling_factor_to_offset
=
\
new_module
.
scaling_factor_to_offset
# (yard1): TODO make this more robust
# (yard1): TODO make this more robust
if
"lm_head"
in
module_name
:
if
"lm_head"
in
module_name
:
logits_processor_module
=
self
.
model
.
get_submodule
(
logits_processor_module
=
self
.
model
.
get_submodule
(
...
@@ -451,7 +598,8 @@ class LoRAModelManager:
...
@@ -451,7 +598,8 @@ class LoRAModelManager:
self
.
_register_packed_modules
(
module_name
)
self
.
_register_packed_modules
(
module_name
)
new_module
.
set_mapping
(
self
.
base_indices
,
self
.
sampler_indices
,
new_module
.
set_mapping
(
self
.
base_indices
,
self
.
sampler_indices
,
self
.
sampler_indices_padded
,
self
.
sampler_indices_padded
,
self
.
embeddings_indices
,
self
.
indices_len
)
self
.
embeddings_indices
,
self
.
long_lora_indices
,
self
.
indices_len
)
def
register_module
(
self
,
module_name
:
str
,
module
:
"BaseLayerWithLoRA"
):
def
register_module
(
self
,
module_name
:
str
,
module
:
"BaseLayerWithLoRA"
):
assert
isinstance
(
module
,
BaseLayerWithLoRA
)
assert
isinstance
(
module
,
BaseLayerWithLoRA
)
...
@@ -461,12 +609,14 @@ class LoRAModelManager:
...
@@ -461,12 +609,14 @@ class LoRAModelManager:
self
,
self
,
lora_id
:
int
,
lora_id
:
int
,
rank
:
int
,
rank
:
int
,
scaling_factor
:
Optional
[
float
],
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
)
->
LoRAModel
:
embedding_modules
:
Optional
[
Dict
[
str
,
str
]]
=
None
)
->
LoRAModel
:
"""Create zero-initialized LoRAModel for warmup."""
"""Create zero-initialized LoRAModel for warmup."""
model
=
LoRAModel
(
lora_id
,
rank
,
{})
model
=
LoRAModel
(
lora_id
,
rank
,
{}
,
scaling_factor
)
for
module_name
,
module
in
self
.
model
.
named_modules
():
for
module_name
,
module
in
self
.
model
.
named_modules
():
if
not
self
.
_match_target_modules
(
module_name
)
or
not
isinstance
(
if
not
self
.
_match_target_modules
(
module_name
)
or
not
isinstance
(
module
,
BaseLayerWithLoRA
):
module
,
BaseLayerWithLoRA
)
or
isinstance
(
module
,
LinearScalingRotaryEmbeddingWithLora
):
continue
continue
parts
=
module_name
.
split
(
"."
)
parts
=
module_name
.
split
(
"."
)
if
module_name
not
in
self
.
packed_modules
:
if
module_name
not
in
self
.
packed_modules
:
...
@@ -596,6 +746,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
...
@@ -596,6 +746,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
def
add_lora
(
self
,
lora
:
LoRAModel
)
->
bool
:
def
add_lora
(
self
,
lora
:
LoRAModel
)
->
bool
:
"""Add a LoRAModel to the manager."""
"""Add a LoRAModel to the manager."""
logger
.
debug
(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s"
,
lora
.
id
,
lora
.
id
,
lora
.
scaling_factor
)
if
lora
.
id
not
in
self
.
_registered_loras
:
if
lora
.
id
not
in
self
.
_registered_loras
:
self
.
_add_lora
(
lora
)
self
.
_add_lora
(
lora
)
was_added
=
True
was_added
=
True
...
...
vllm/lora/request.py
View file @
b9e12416
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
@
dataclass
@
dataclass
...
@@ -18,6 +19,7 @@ class LoRARequest:
...
@@ -18,6 +19,7 @@ class LoRARequest:
lora_name
:
str
lora_name
:
str
lora_int_id
:
int
lora_int_id
:
int
lora_local_path
:
str
lora_local_path
:
str
long_lora_max_len
:
Optional
[
int
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
lora_int_id
<
1
:
if
self
.
lora_int_id
<
1
:
...
...
vllm/lora/utils.py
View file @
b9e12416
...
@@ -13,6 +13,7 @@ from vllm.lora.fully_sharded_layers import (
...
@@ -13,6 +13,7 @@ from vllm.lora.fully_sharded_layers import (
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
ColumnParallelLinearWithLoRA
,
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
ColumnParallelLinearWithLoRA
,
LinearScalingRotaryEmbeddingWithLora
,
LogitsProcessorWithLoRA
,
LogitsProcessorWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedQKVParallelLinearWithLora
,
MergedQKVParallelLinearWithLora
,
...
@@ -26,12 +27,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
...
@@ -26,12 +27,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_all_lora_classes
:
Set
[
Type
[
BaseLayerWithLoRA
]]
=
{
_all_lora_classes
:
Set
[
Type
[
BaseLayerWithLoRA
]]
=
{
VocabParallelEmbeddingWithLoRA
,
ColumnParallelLinearWithLoRA
,
VocabParallelEmbeddingWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
QKVParallelLinearWithLora
,
ColumnParallelLinearWithLoRA
,
MergedQKVParallelLinearWithLora
,
RowParallelLinearWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
LogitsProcessorWithLoRA
,
ColumnParallelLinearWithShardedLoRA
,
QKVParallelLinearWithLora
,
MergedQKVParallelLinearWithLora
,
RowParallelLinearWithLoRA
,
LogitsProcessorWithLoRA
,
ColumnParallelLinearWithShardedLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithShardedLora
,
RowParallelLinearWithShardedLoRA
MergedQKVParallelLinearWithShardedLora
,
RowParallelLinearWithShardedLoRA
,
LinearScalingRotaryEmbeddingWithLora
,
}
}
...
...
vllm/lora/worker_manager.py
View file @
b9e12416
from
abc
import
ABC
,
abstractmethod
,
abstractproperty
from
abc
import
ABC
,
abstractmethod
,
abstractproperty
from
typing
import
Any
,
Dict
,
List
,
Set
,
Type
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Set
,
Type
,
Union
import
torch
import
torch
...
@@ -16,15 +17,31 @@ logger = init_logger(__name__)
...
@@ -16,15 +17,31 @@ logger = init_logger(__name__)
class
AbstractWorkerLoRAManager
(
ABC
):
class
AbstractWorkerLoRAManager
(
ABC
):
"""Abstract class for managing LoRA models on the worker side."""
"""Abstract class for managing LoRA models on the worker side."""
def
__init__
(
self
,
max_num_seqs
:
int
,
max_num_batched_tokens
:
int
,
def
__init__
(
self
,
vocab_size
:
int
,
lora_config
:
LoRAConfig
,
max_num_seqs
:
int
,
device
:
torch
.
device
):
max_num_batched_tokens
:
int
,
vocab_size
:
int
,
lora_config
:
LoRAConfig
,
device
:
torch
.
device
,
max_position_embeddings
:
Optional
[
int
]
=
None
):
self
.
max_num_seqs
=
max_num_seqs
self
.
max_num_seqs
=
max_num_seqs
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_position_embeddings
=
max_position_embeddings
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
device
=
device
self
.
device
=
device
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
# If False, do not cache. If None, cache is empty.
self
.
_cached_dummy_lora
:
Union
[
None
,
Literal
[
False
],
LoRAModel
]
=
False
@
contextmanager
def
dummy_lora_cache
(
self
):
"""Use this context manager to reuse the dummy lora model
to avoid creating it repeatedly."""
self
.
_cached_dummy_lora
=
None
yield
self
.
_cached_dummy_lora
=
False
@
abstractproperty
@
abstractproperty
def
is_enabled
(
self
)
->
bool
:
def
is_enabled
(
self
)
->
bool
:
...
...
...
@@ -80,14 +97,21 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
...
@@ -80,14 +97,21 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
embedding_modules
:
Dict
[
str
,
str
],
embedding_modules
:
Dict
[
str
,
str
],
embedding_padding_modules
:
List
[
str
],
embedding_padding_modules
:
List
[
str
],
lora_model_cls
:
Type
[
LoRAModel
]
=
LoRAModel
,
lora_model_cls
:
Type
[
LoRAModel
]
=
LoRAModel
,
max_position_embeddings
:
Optional
[
int
]
=
None
,
):
):
self
.
_lora_model_cls
=
lora_model_cls
self
.
_lora_model_cls
=
lora_model_cls
self
.
embedding_modules
=
embedding_modules
self
.
embedding_modules
=
embedding_modules
self
.
embedding_padding_modules
=
embedding_padding_modules
self
.
embedding_padding_modules
=
embedding_padding_modules
# Lazily initialized by create_lora_manager.
# Lazily initialized by create_lora_manager.
self
.
_lora_manager
:
LoRAModelManager
self
.
_lora_manager
:
LoRAModelManager
super
().
__init__
(
max_num_seqs
,
max_num_batched_tokens
,
vocab_size
,
super
().
__init__
(
lora_config
,
device
)
max_num_seqs
,
max_num_batched_tokens
,
vocab_size
,
lora_config
,
device
,
max_position_embeddings
=
max_position_embeddings
,
)
@
property
@
property
def
is_enabled
(
self
)
->
bool
:
def
is_enabled
(
self
)
->
bool
:
...
@@ -150,6 +174,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
...
@@ -150,6 +174,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora_request
.
lora_local_path
,
lora_request
.
lora_local_path
,
expected_lora_modules
,
expected_lora_modules
,
max_position_embeddings
=
self
.
max_position_embeddings
,
lora_model_id
=
lora_request
.
lora_int_id
,
lora_model_id
=
lora_request
.
lora_int_id
,
device
=
"cpu"
,
device
=
"cpu"
,
dtype
=
self
.
lora_config
.
lora_dtype
,
dtype
=
self
.
lora_config
.
lora_dtype
,
...
@@ -174,9 +199,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
...
@@ -174,9 +199,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def
add_dummy_lora
(
self
,
lora_request
:
LoRARequest
,
rank
:
int
)
->
bool
:
def
add_dummy_lora
(
self
,
lora_request
:
LoRARequest
,
rank
:
int
)
->
bool
:
if
lora_request
.
lora_int_id
in
self
.
list_loras
():
if
lora_request
.
lora_int_id
in
self
.
list_loras
():
return
False
return
False
return
self
.
_lora_manager
.
add_lora
(
if
isinstance
(
self
.
_cached_dummy_lora
,
LoRAModel
):
self
.
_lora_manager
.
create_dummy_lora
(
lora_request
.
lora_int_id
,
dummy_lora
=
self
.
_cached_dummy_lora
.
clone
(
rank
,
self
.
embedding_modules
))
lora_request
.
lora_int_id
)
else
:
dummy_lora
=
self
.
_lora_manager
.
create_dummy_lora
(
lora_request
.
lora_int_id
,
rank
,
1
,
self
.
embedding_modules
)
if
self
.
_cached_dummy_lora
is
None
:
self
.
_cached_dummy_lora
=
dummy_lora
return
self
.
_lora_manager
.
add_lora
(
dummy_lora
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
if
lora_request
.
lora_int_id
in
self
.
list_loras
():
if
lora_request
.
lora_int_id
in
self
.
list_loras
():
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
b9e12416
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_
moe
,
get_config_file_name
)
fused_
experts
,
fused_moe
,
fused_topk
,
get_config_file_name
)
__all__
=
[
__all__
=
[
"fused_moe"
,
"fused_moe"
,
"fused_topk"
,
"fused_experts"
,
"get_config_file_name"
,
"get_config_file_name"
,
]
]
vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json
0 → 100644
View file @
b9e12416
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
64
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
32
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
64
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
1
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_stages"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json
0 → 100644
View file @
b9e12416
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
32
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
},
"48"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
},
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
},
"256"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json
0 → 100644
View file @
b9e12416
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
32
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_stages"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_stages"
:
0
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_stages"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_stages"
:
0
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_stages"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
0 → 100644
View file @
b9e12416
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
32
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_stages"
:
0
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_stages"
:
0
},
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
1
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
b9e12416
...
@@ -308,60 +308,16 @@ def get_moe_configs(E: int, N: int,
...
@@ -308,60 +308,16 @@ def get_moe_configs(E: int, N: int,
return
None
return
None
def
fused_
moe
(
def
fused_
topk
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
inplace
:
bool
=
False
,
):
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
"Number of tokens mismatch"
)
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
M
,
_
=
hidden_states
.
shape
M
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
if
is_hip
():
if
is_hip
():
# The MoE kernels are not yet supported on ROCm.
# The MoE kernels are not yet supported on ROCm.
...
@@ -393,6 +349,33 @@ def fused_moe(
...
@@ -393,6 +349,33 @@ def fused_moe(
del
token_expert_indicies
# Not used. Will be used in the future.
del
token_expert_indicies
# Not used. Will be used in the future.
if
renormalize
:
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
,
topk_ids
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
):
# Check constraints.
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
M
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
if
override_config
:
if
override_config
:
config
=
override_config
config
=
override_config
...
@@ -477,3 +460,63 @@ def fused_moe(
...
@@ -477,3 +460,63 @@ def fused_moe(
out
=
hidden_states
)
out
=
hidden_states
)
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
)
dim
=
1
)
def
fused_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
inplace
:
bool
=
False
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
return
fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
inplace
=
inplace
,
override_config
=
override_config
,
use_fp8
=
use_fp8
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
)
vllm/model_executor/layers/linear.py
View file @
b9e12416
...
@@ -59,7 +59,6 @@ class LinearMethodBase(QuantizeMethodBase):
...
@@ -59,7 +59,6 @@ class LinearMethodBase(QuantizeMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""Apply the weights in layer to the input tensor.
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
Expects create_weights to have been called before on the layer."""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -81,8 +80,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -81,8 +80,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
**
extra_weight_attrs
):
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
input_size_per_partition
,
dtype
=
params_dtype
),
dtype
=
params_dtype
),
requires_grad
=
False
)
requires_grad
=
False
)
...
@@ -161,15 +159,13 @@ class ReplicatedLinear(LinearBase):
...
@@ -161,15 +159,13 @@ class ReplicatedLinear(LinearBase):
quant_config: Quantization configure.
quant_config: Quantization configure.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
input_size
:
int
,
input_size
:
int
,
output_size
:
int
,
output_size
:
int
,
bias
:
bool
=
True
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
)
...
@@ -222,17 +218,15 @@ class ColumnParallelLinear(LinearBase):
...
@@ -222,17 +218,15 @@ class ColumnParallelLinear(LinearBase):
the list would be size 3.
the list would be size 3.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
input_size
:
int
,
input_size
:
int
,
output_size
:
int
,
output_size
:
int
,
bias
:
bool
=
True
,
bias
:
bool
=
True
,
gather_output
:
bool
=
False
,
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
):
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
)
...
@@ -240,18 +234,26 @@ class ColumnParallelLinear(LinearBase):
...
@@ -240,18 +234,26 @@ class ColumnParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
output_size_per_partition
=
divide
(
output_size
,
tp_size
)
assert
self
.
quant_method
is
not
None
self
.
output_size_per_partition
=
divide
(
self
.
output_size
,
tp_size
)
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
# If QKV or MergedColumn, use output size of each partition.
if
hasattr
(
self
,
"output_sizes"
):
self
.
output_partition_sizes
=
[
divide
(
output_size
,
tp_size
)
for
output_size
in
self
.
output_sizes
]
if
output_sizes
is
None
:
if
output_sizes
is
None
:
output_sizes
=
[
output_size
]
output_sizes
=
[
output_size
]
# All the linear layer supports quant method.
self
.
quant_method
.
create_weights
(
assert
self
.
quant_method
is
not
None
layer
=
self
,
self
.
quant_method
.
create_weights
(
self
,
input_size_per_partition
=
self
.
input_size
,
self
.
input_size
,
output_partition_sizes
=
self
.
output_partition_sizes
,
[
x
//
tp_size
for
x
in
output_sizes
],
input_size
=
self
.
input_size
,
self
.
input_size
,
output_size
=
self
.
output_size
,
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
torch
.
empty
(
self
.
output_size_per_partition
,
...
@@ -333,24 +335,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -333,24 +335,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure.
quant_config: Quantization configure.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
input_size
:
int
,
input_size
:
int
,
output_sizes
:
List
[
int
],
output_sizes
:
List
[
int
],
bias
:
bool
=
True
,
bias
:
bool
=
True
,
gather_output
:
bool
=
False
,
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
super
().
__init__
(
input_size
,
sum
(
output_sizes
),
bias
,
gather_output
,
super
().
__init__
(
input_size
=
input_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
output_size
=
sum
(
output_sizes
),
self
.
output_sizes
)
bias
=
bias
,
gather_output
=
gather_output
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
...
@@ -360,6 +365,26 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -360,6 +365,26 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
# Special case for AQLM codebooks.
# Special case for AQLM codebooks.
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
param_shard_splitter
=
getattr
(
param
,
"shard_splitter"
,
None
)
if
output_dim
is
not
None
and
param_shard_splitter
is
not
None
:
raise
NotImplementedError
(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if
loaded_shard_id
is
None
and
param_shard_splitter
is
not
None
:
raise
NotImplementedError
(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# Special case for Fp8 scales.
# Special case for Fp8 scales.
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
None
)
None
)
...
@@ -424,6 +449,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -424,6 +449,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size
=
loaded_weight
.
shape
[
0
]
shard_size
=
loaded_weight
.
shape
[
0
]
shard_offset
=
loaded_shard_id
*
shard_size
shard_offset
=
loaded_shard_id
*
shard_size
param_data
=
param_data
.
narrow
(
0
,
shard_offset
,
shard_size
)
param_data
=
param_data
.
narrow
(
0
,
shard_offset
,
shard_size
)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif
param_shard_splitter
is
not
None
:
logical_widths
=
getattr
(
param
,
"logical_widths"
,
None
)
param_data
,
loaded_weight
=
param_shard_splitter
(
param_data
,
loaded_weight
,
loaded_shard_id
,
logical_widths
)
# Special case for Fp8 scales.
# Special case for Fp8 scales.
elif
fp8_scales_shard_indexer
is
not
None
:
elif
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
...
@@ -436,7 +468,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -436,7 +468,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions."
)
"the same for all partitions."
)
if
fp8_scales_shard_indexer
is
None
:
if
len
(
param_data
.
shape
)
==
0
:
param_data
=
param_data
.
reshape
(
1
)
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
assert
param_data_
.
shape
==
loaded_weight
.
shape
assert
param_data_
.
shape
==
loaded_weight
.
shape
param_data_
.
copy_
(
loaded_weight
)
param_data_
.
copy_
(
loaded_weight
)
...
@@ -448,6 +487,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -448,6 +487,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
class
QKVParallelLinear
(
ColumnParallelLinear
):
"""Linear layers for the attention's QKV transformation.
"""Linear layers for the attention's QKV transformation.
...
@@ -472,17 +512,15 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -472,17 +512,15 @@ class QKVParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure.
quant_config: Quantization configure.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
hidden_size
:
int
,
hidden_size
:
int
,
head_size
:
int
,
head_size
:
int
,
total_num_heads
:
int
,
total_num_heads
:
int
,
total_num_kv_heads
:
Optional
[
int
]
=
None
,
total_num_kv_heads
:
Optional
[
int
]
=
None
,
bias
:
bool
=
True
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
total_num_heads
=
total_num_heads
self
.
total_num_heads
=
total_num_heads
...
@@ -502,14 +540,18 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -502,14 +540,18 @@ class QKVParallelLinear(ColumnParallelLinear):
input_size
=
self
.
hidden_size
input_size
=
self
.
hidden_size
output_size
=
(
self
.
num_heads
+
output_size
=
(
self
.
num_heads
+
2
*
self
.
num_kv_heads
)
*
tp_size
*
self
.
head_size
2
*
self
.
num_kv_heads
)
*
tp_size
*
self
.
head_size
output_sizes
=
[
self
.
output_sizes
=
[
self
.
num_heads
*
tp_size
*
self
.
head_size
,
self
.
num_heads
*
self
.
head_size
*
tp_size
,
# q_proj
self
.
num_kv_heads
*
tp_size
*
self
.
head_size
,
self
.
num_kv_heads
*
self
.
head_size
*
tp_size
,
# k_proj
self
.
num_kv_heads
*
tp_size
*
self
.
head_size
self
.
num_kv_heads
*
self
.
head_size
*
tp_size
,
# v_proj
]
]
super
().
__init__
(
input_size
=
input_size
,
super
().
__init__
(
input_size
,
output_size
,
bias
,
False
,
skip_bias_add
,
output_size
=
output_size
,
params_dtype
,
quant_config
,
output_sizes
)
bias
=
bias
,
gather_output
=
False
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
def
weight_loader
(
self
,
...
@@ -520,6 +562,26 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -520,6 +562,26 @@ class QKVParallelLinear(ColumnParallelLinear):
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
# Special case for AQLM codebooks.
# Special case for AQLM codebooks.
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
param_shard_splitter
=
getattr
(
param
,
"shard_splitter"
,
None
)
if
output_dim
is
not
None
and
param_shard_splitter
is
not
None
:
raise
NotImplementedError
(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if
loaded_shard_id
is
None
and
param_shard_splitter
is
not
None
:
raise
NotImplementedError
(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# Special case for Fp8 scales.
# Special case for Fp8 scales.
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
None
)
None
)
...
@@ -558,6 +620,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -558,6 +620,8 @@ class QKVParallelLinear(ColumnParallelLinear):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
assert
loaded_shard_id
in
[
"q"
,
"k"
,
"v"
]
assert
loaded_shard_id
in
[
"q"
,
"k"
,
"v"
]
# If output dim is defined, use the default loading process.
if
output_dim
is
not
None
:
if
output_dim
is
not
None
:
if
loaded_shard_id
==
"q"
:
if
loaded_shard_id
==
"q"
:
shard_offset
=
0
shard_offset
=
0
...
@@ -601,6 +665,12 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -601,6 +665,12 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_index
=
[
"q"
,
"k"
,
"v"
].
index
(
loaded_shard_id
)
shard_index
=
[
"q"
,
"k"
,
"v"
].
index
(
loaded_shard_id
)
param_data
=
param_data
.
narrow
(
0
,
shard_index
*
shard_size
,
param_data
=
param_data
.
narrow
(
0
,
shard_index
*
shard_size
,
shard_size
)
shard_size
)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif
param_shard_splitter
is
not
None
:
logical_widths
=
getattr
(
param
,
"logical_widths"
,
None
)
param_data
,
loaded_weight
=
param_shard_splitter
(
param_data
,
loaded_weight
,
loaded_shard_id
,
logical_widths
)
# Special case for Fp8 scales.
# Special case for Fp8 scales.
elif
fp8_scales_shard_indexer
is
not
None
:
elif
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
...
@@ -612,6 +682,11 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -612,6 +682,11 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
"for all partitions."
)
if
len
(
param_data
.
shape
)
==
0
:
param_data
=
param_data
.
reshape
(
1
)
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
assert
param_data_
.
shape
==
loaded_weight
.
shape
assert
param_data_
.
shape
==
loaded_weight
.
shape
...
@@ -650,17 +725,15 @@ class RowParallelLinear(LinearBase):
...
@@ -650,17 +725,15 @@ class RowParallelLinear(LinearBase):
quant_config: Quantization configure.
quant_config: Quantization configure.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
input_size
:
int
,
input_size
:
int
,
output_size
:
int
,
output_size
:
int
,
bias
:
bool
=
True
,
bias
:
bool
=
True
,
input_is_parallel
:
bool
=
True
,
input_is_parallel
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
)
...
@@ -670,16 +743,15 @@ class RowParallelLinear(LinearBase):
...
@@ -670,16 +743,15 @@ class RowParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
,
self
.
quant_method
.
create_weights
(
self
.
input_size_per_partition
,
layer
=
self
,
[
self
.
out
put_size
]
,
input_size_per_partition
=
self
.
in
put_size
_per_partition
,
self
.
in
put_size
,
output_partition_sizes
=
[
self
.
out
put_size
]
,
self
.
out
put_size
,
input_size
=
self
.
in
put_size
,
self
.
params_dtyp
e
,
output_size
=
self
.
output_siz
e
,
weight_loader
=
self
.
weight_loader
)
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
raise
ValueError
(
"When not reduce the results, adding bias to the "
"results can lead to incorrect results"
)
"results can lead to incorrect results"
)
...
@@ -708,12 +780,16 @@ class RowParallelLinear(LinearBase):
...
@@ -708,12 +780,16 @@ class RowParallelLinear(LinearBase):
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
shard_size
)
shard_size
)
# Special case for Fp8 scales.
# Special case for Fp8 scales.
elif
fp8_scales_shard_indexer
is
not
None
:
elif
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
,
loaded_weight
,
shard_id
=
0
)
shard_id
=
0
)
if
fp8_scales_shard_indexer
is
None
and
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
...
...
vllm/model_executor/layers/logits_processor.py
View file @
b9e12416
"""A layer that compute logits from hidden_stats."""
"""A layer that compute logits from hidden_stats."""
import
inspect
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
...
@@ -95,15 +96,25 @@ def _apply_logits_processors(
...
@@ -95,15 +96,25 @@ def _apply_logits_processors(
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
logits_processors
=
sampling_params
.
logits_processors
logits_processors
=
sampling_params
.
logits_processors
if
logits_processors
:
if
logits_processors
:
found_logits_processors
=
True
found_logits_processors
=
True
for
seq_id
,
logits_row_idx
in
zip
(
seq_ids
,
for
seq_id
,
logits_row_idx
in
zip
(
seq_ids
,
seq_group
.
sample_indices
):
seq_group
.
sample_indices
):
logits_row
=
logits
[
logits_row_idx
]
logits_row
=
logits
[
logits_row_idx
]
token_ids
=
seq_group
.
seq_data
[
seq_id
].
output_token_ids
past_tokens_ids
=
seq_group
.
seq_data
[
seq_id
].
output_token_ids
prompt_tokens_ids
=
seq_group
.
seq_data
[
seq_id
].
prompt_token_ids
for
logits_processor
in
logits_processors
:
for
logits_processor
in
logits_processors
:
logits_row
=
logits_processor
(
token_ids
,
logits_row
)
parameters
=
inspect
.
signature
(
logits_processor
).
parameters
if
len
(
parameters
)
==
3
:
logits_row
=
logits_processor
(
prompt_tokens_ids
,
past_tokens_ids
,
logits_row
)
else
:
logits_row
=
logits_processor
(
past_tokens_ids
,
logits_row
)
logits
[
logits_row_idx
]
=
logits_row
logits
[
logits_row_idx
]
=
logits_row
logits_processed
+=
len
(
seq_group
.
sample_indices
)
+
len
(
logits_processed
+=
len
(
seq_group
.
sample_indices
)
+
len
(
...
...
Prev
1
…
12
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment