Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6640dc0b
Commit
6640dc0b
authored
Jun 20, 2024
by
zhuwenwen
Browse files
Merge branch 'main' of
http://10.6.10.68/dcutoolkit/deeplearing/vllm
parents
44d4d334
83e4e0fe
Changes
110
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1116 additions
and
76 deletions
+1116
-76
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+2
-2
vllm/envs.py
vllm/envs.py
+6
-0
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+3
-3
vllm/executor/tpu_executor.py
vllm/executor/tpu_executor.py
+101
-0
vllm/inputs.py
vllm/inputs.py
+1
-1
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+3
-1
vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json
...onfigs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json
...onfigs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json
...configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json
...configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json
+146
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+36
-8
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+1
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py
...on/compressed_tensors/schemes/compressed_tensors_w4a16.py
+168
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
...d_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
+2
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
...d_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+2
-2
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+4
-2
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+74
-3
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+38
-7
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+75
-32
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+16
-13
No files found.
vllm/entrypoints/openai/run_batch.py
View file @
6640dc0b
...
...
@@ -5,7 +5,6 @@ from io import StringIO
import
aiohttp
import
vllm
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
BatchRequestInput
,
...
...
@@ -15,6 +14,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
random_uuid
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
...
...
@@ -135,7 +135,7 @@ async def main(args):
if
__name__
==
"__main__"
:
args
=
parse_args
()
logger
.
info
(
"vLLM API server version %s"
,
vllm
.
__version__
)
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
asyncio
.
run
(
main
(
args
))
vllm/envs.py
View file @
6640dc0b
...
...
@@ -27,6 +27,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_XLA_CACHE_PATH
:
str
=
"~/.vllm/xla_cache/"
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"spawn"
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
...
...
@@ -217,6 +218,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_IMAGE_FETCH_TIMEOUT"
,
"5"
)),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH"
:
lambda
:
os
.
getenv
(
"VLLM_XLA_CACHE_PATH"
,
"~/.vllm/xla_cache/"
),
}
# end-env-vars-definition
...
...
vllm/executor/multiproc_gpu_executor.py
View file @
6640dc0b
...
...
@@ -9,7 +9,8 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler
,
WorkerMonitor
)
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
from
vllm.utils
import
(
cuda_device_count_stateless
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
logger
=
init_logger
(
__name__
)
...
...
@@ -33,8 +34,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Disable torch async compiling which won't work with daemonic processes
os
.
environ
[
"TORCHINDUCTOR_COMPILE_THREADS"
]
=
"1"
from
torch.cuda
import
device_count
assert
world_size
<=
device_count
(),
(
assert
world_size
<=
cuda_device_count_stateless
(),
(
"please set tensor_parallel_size to less than max local gpu count"
)
distributed_init_method
=
get_distributed_init_method
(
...
...
vllm/executor/tpu_executor.py
0 → 100644
View file @
6640dc0b
from
typing
import
List
,
Set
,
Tuple
import
torch
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
logger
=
init_logger
(
__name__
)
class TPUExecutor(ExecutorBase):
    """Executor that runs the model through a single ``TPUWorker``.

    The worker is created in-process and stored as ``self.driver_worker``;
    every cache/model call is forwarded to it. The LoRA methods are not
    supported and always raise ``NotImplementedError``.
    """

    def _init_executor(self) -> None:
        """Check the configuration, then create the worker and load the model.

        Chunked prefill and speculative decoding are rejected via assertions.
        A float16/float32 model dtype is replaced by bfloat16 after a warning.
        """
        assert not self.scheduler_config.chunked_prefill_enabled, (
            "Chunked prefill is not yet supported for TPU backend")
        assert not self.speculative_config, (
            "Speculative decoding is not yet supported for TPU backend")

        requested_dtype = self.model_config.dtype
        if requested_dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", requested_dtype)
            self.model_config.dtype = torch.bfloat16

        # Build the worker and put the model on the device.
        self._init_worker()

    def _init_worker(self):
        """Create the single ``TPUWorker``, initialize its device and load
        the model."""
        # The worker module is imported inside the method, not at file level.
        from vllm.worker.tpu_worker import TPUWorker

        assert self.parallel_config.world_size == 1, (
            "TPUExecutor currently only supports a single TPU chip.")

        init_method = get_distributed_init_method(get_ip(), get_open_port())
        worker = TPUWorker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            self.cache_config,
            self.load_config,
            self.vision_language_config,
            local_rank=0,
            rank=0,
            distributed_init_method=init_method,
        )
        self.driver_worker = worker
        worker.init_device()
        worker.load_model()

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Log the block counts and let the worker set up the KV cache."""
        # NOTE: The block counts are logged here because other executors can
        # own more than one worker. Logging at the engine level would also be
        # possible, but the device is not yet abstracted away for non-GPU
        # configurations.
        logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Ask the worker how many KV blocks are available and return its
        answer unchanged."""
        return self.driver_worker.determine_num_available_blocks()

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Forward the request to the worker and return its output."""
        return self.driver_worker.execute_model(execute_model_req)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Unsupported on this backend; always raises."""
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def remove_lora(self, lora_id: int) -> bool:
        """Unsupported on this backend; always raises."""
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def list_loras(self) -> Set[int]:
        """Unsupported on this backend; always raises."""
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def check_health(self) -> None:
        """No-op health check."""
        # As long as the executor is running it is considered healthy.
        return None
class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):
    """Asynchronous counterpart of ``TPUExecutor``."""

    async def execute_model_async(
        self,
        sexecute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Run the worker's ``execute_model`` through ``make_async`` and
        return its result.

        The return annotation matches ``TPUExecutor.execute_model``, which
        forwards to the same worker call and is declared to return
        ``List[SamplerOutput]``.
        """
        # NOTE(review): `sexecute_model_req` looks like a typo of
        # `execute_model_req`; the name is kept so the signature is unchanged.
        # Confirm against the base class before renaming.
        run_async = make_async(self.driver_worker.execute_model)
        return await run_async(sexecute_model_req)
vllm/inputs.py
View file @
6640dc0b
...
...
@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
from
typing_extensions
import
NotRequired
if
TYPE_CHECKING
:
from
vllm.
sequence
import
MultiModalData
from
vllm.
multimodal
import
MultiModalData
class
ParsedText
(
TypedDict
):
...
...
vllm/model_executor/custom_op.py
View file @
6640dc0b
import
torch.nn
as
nn
from
vllm.utils
import
is_cpu
,
is_hip
from
vllm.utils
import
is_cpu
,
is_hip
,
is_tpu
class
CustomOp
(
nn
.
Module
):
...
...
@@ -56,5 +56,7 @@ class CustomOp(nn.Module):
return
self
.
forward_hip
elif
is_cpu
():
return
self
.
forward_cpu
elif
is_tpu
():
return
self
.
forward_tpu
else
:
return
self
.
forward_cuda
vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json
0 → 100644
View file @
6640dc0b
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
5
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
5
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"128"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
}
}
vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json
0 → 100644
View file @
6640dc0b
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"128"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json
0 → 100644
View file @
6640dc0b
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"128"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
}
}
vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json
0 → 100644
View file @
6640dc0b
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
6640dc0b
...
...
@@ -7,8 +7,8 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
CompressedTensorsW
8A8DynamicToken
,
CompressedTensorsW8A8StaticTensor
)
CompressedTensorsScheme
,
CompressedTensorsW
4A16
,
CompressedTensorsW8A8DynamicToken
,
CompressedTensorsW8A8StaticTensor
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationArgs
,
QuantizationStrategy
,
find_first_name_or_class_match
)
...
...
@@ -47,16 +47,27 @@ class CompressedTensorsConfig(QuantizationConfig):
layer_quant_details
:
Dict
[
str
,
Any
]
=
dict
()
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
# quantized, a weights key indicating how the weights are quantized,
# and a list of targets under the `targets` key, dictating which
# layers are impacted by the quantization details. The quantization
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
for
key
,
quant_config
in
config
[
"config_groups"
].
items
():
targets
=
quant_config
.
get
(
"targets"
)
for
target
in
targets
:
layer_quant_details
[
target
]
=
{}
layer_quant_details
[
target
][
"weight"
]
=
QuantizationArgs
.
parse_obj
(
"weight
s
"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"weights"
))
layer_quant_details
[
target
][
"input"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"input_activations"
))
try
:
layer_quant_details
[
target
][
"input_activations"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"input_activations"
))
except
Exception
:
layer_quant_details
[
target
][
"input_activations"
]
=
None
return
cls
(
layer_quant_details
=
layer_quant_details
,
ignore
=
ignore
)
...
...
@@ -86,8 +97,23 @@ class CompressedTensorsConfig(QuantizationConfig):
return
is_8_bits
and
is_token_tensor
and
is_symmetric
and
is_dynamic
def
_is_w4a16
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant_none
=
input_quant
is
None
is_4_bits
=
weight_quant
.
num_bits
==
4
is_symmetric
=
weight_quant
.
symmetric
is_static
=
not
weight_quant
.
dynamic
return
is_4_bits
and
input_quant_none
and
is_symmetric
and
is_static
def
_get_schema
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
"CompressedTensorsScheme"
:
if
self
.
_is_w4a16
(
weight_quant
,
input_quant
):
return
CompressedTensorsW4A16
(
num_bits
=
weight_quant
.
num_bits
,
strategy
=
weight_quant
.
strategy
,
group_size
=
weight_quant
.
group_size
)
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8StaticTensor
()
...
...
@@ -113,8 +139,9 @@ class CompressedTensorsConfig(QuantizationConfig):
raise
ValueError
(
f
"Could not find quantization details for
{
layer
}
."
)
return
self
.
_get_schema
(
weight_quant
=
layer_quant_details
[
"weight"
],
input_quant
=
layer_quant_details
[
"input"
])
return
self
.
_get_schema
(
weight_quant
=
layer_quant_details
[
"weights"
],
input_quant
=
layer_quant_details
[
"input_activations"
])
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
...
...
@@ -140,6 +167,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
layer
=
layer
,
input_size_per_partition
=
input_size_per_partition
,
output_partition_sizes
=
output_partition_sizes
,
input_size
=
input_size
,
output_size
=
output_size
,
params_dtype
=
params_dtype
,
weight_loader
=
weight_loader
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
6640dc0b
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
# noqa: F401
from
.compressed_tensors_unquantized
import
(
# noqa: F401
CompressedTensorsUnquantized
)
from
.compressed_tensors_w4a16
import
CompressedTensorsW4A16
# noqa: F401
from
.compressed_tensors_w8a8_dynamictoken
import
(
# noqa: F401, E501
CompressedTensorsW8A8DynamicToken
)
from
.compressed_tensors_w8a8_statictensor
import
(
# noqa: F401, E501
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py
0 → 100644
View file @
6640dc0b
from
typing
import
Callable
,
List
,
Optional
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQMarlinState
,
marlin_permute_scales
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsW4A16"
]
class CompressedTensorsW4A16(CompressedTensorsScheme):
    """Compressed-tensors scheme with packed low-bit weights.

    Weights are registered packed into int32 and, on the first call to
    ``apply_weights``, repacked with the GPTQ Marlin ops before the GEMM.
    """

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        """Store the quantization settings.

        Raises:
            ValueError: if ``strategy`` is ``"group"`` and no ``group_size``
                is given.
        """
        self.num_bits = num_bits
        self.strategy = strategy
        self.group_size = group_size

        if self.strategy == "group" and self.group_size is None:
            raise ValueError(
                "group_size must be given when using strategy group")

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Register the packed weight, scale and shape parameters on ``layer``
        and attach the state later read by ``apply_weights``."""
        # Number of quantized values stored in one int32.
        values_per_int32 = 32 // self.num_bits
        out_per_partition = sum(output_partition_sizes)

        # Without an explicit group size, the whole input dimension is one
        # group.
        effective_group = (input_size
                           if self.group_size is None else self.group_size)

        # Scales follow the per-partition input size only when a group size
        # is configured and the input dimension is actually partitioned.
        if (self.group_size is not None
                and input_size != input_size_per_partition):
            scale_input_dim = 1
            num_scale_groups = input_size_per_partition // effective_group
        else:
            scale_input_dim = None
            num_scale_groups = input_size // effective_group

        packed_weight = Parameter(
            torch.empty(
                out_per_partition,
                input_size_per_partition // values_per_int32,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        packed_weight_attrs = {
            "input_dim": 1,
            "output_dim": 0,
            "packed_dim": 1,
            "pack_factor": values_per_int32
        }
        set_weight_attrs(packed_weight, packed_weight_attrs)
        set_weight_attrs(packed_weight, {"weight_loader": weight_loader})
        layer.register_parameter("weight_packed", packed_weight)

        scale = Parameter(
            torch.empty(
                out_per_partition,
                num_scale_groups,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scale, {"weight_loader": weight_loader})
        set_weight_attrs(scale, {
            "input_dim": scale_input_dim,
            "output_dim": 0
        })
        layer.register_parameter("weight_scale", scale)

        # Two-element tensor holding the shape the weights had before they
        # were packed.
        unpacked_shape = Parameter(torch.empty(2, dtype=torch.int64),
                                   requires_grad=False)
        layer.register_parameter("weight_shape", unpacked_shape)
        set_weight_attrs(unpacked_shape, {"weight_loader": weight_loader})

        # State consumed by apply_weights().
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_per_partition
        layer.input_size = input_size
        layer.marlin_state = GPTQMarlinState.REPACK
        layer.is_k_full = True
        layer.group_size = effective_group

        workspace_len = (out_per_partition //
                         GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
        layer.workspace = torch.zeros(workspace_len,
                                      dtype=torch.int,
                                      requires_grad=False)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        """Multiply ``x`` by the quantized weights of ``layer``.

        On the first call (``marlin_state`` is ``REPACK``) the packed weights
        and scales are converted in place and the state becomes ``READY``.
        The result keeps the leading dimensions of ``x`` and has
        ``layer.output_size_per_partition`` as its last dimension.
        """
        x_2d = x.reshape(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_n = layer.output_size_per_partition
        size_k = layer.input_size_per_partition

        result_shape = x.shape[:-1] + (size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # The freshly produced tensors must overwrite the tensors vLLM
            # already registered as parameters (those would not be freed).
            def overwrite_param(name, new_t):
                # resize_() is used on purpose: it makes sure the existing
                # buffer is reused.
                target = getattr(layer, name)
                target.resize_(new_t.shape)
                target.copy_(new_t)

            device = layer.weight_packed.device

            # Reset the g_idx related tensors to empty ones.
            layer.g_idx = Parameter(
                torch.empty(0, dtype=torch.int, device=device),
                requires_grad=False)
            layer.g_idx_sort_indices = Parameter(
                torch.empty(0, dtype=torch.int, device=device),
                requires_grad=False)

            # Repack the weights.
            repacked_weight = ops.gptq_marlin_repack(
                layer.weight_packed.t().contiguous(),
                layer.g_idx_sort_indices, size_k, size_n, self.num_bits)
            overwrite_param("weight_packed", repacked_weight)

            # Permute the scales.
            permuted_scales = marlin_permute_scales(
                layer.weight_scale.squeeze().t().contiguous(), size_k,
                size_n, layer.group_size, self.num_bits)
            overwrite_param("weight_scale", permuted_scales)

        gemm_out = ops.gptq_marlin_gemm(x_2d, layer.weight_packed,
                                        layer.weight_scale, layer.g_idx,
                                        layer.g_idx_sort_indices,
                                        layer.workspace, self.num_bits,
                                        size_m, size_n, size_k,
                                        layer.is_k_full)
        return gemm_out.reshape(result_shape)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
View file @
6640dc0b
...
...
@@ -81,5 +81,5 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
weight_scale
=
layer
.
weight_scale
x_q
,
input_scales
=
custom_ops
.
scaled_int8_quant
(
x
)
return
custom_ops
.
cutlass_scaled_mm
_dq
(
x_q
,
weight
.
t
(),
input_scales
,
weight_scale
,
x
.
dtype
)
return
custom_ops
.
cutlass_scaled_mm
(
x_q
,
weight
.
t
(),
input_scales
,
weight_scale
,
x
.
dtype
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
View file @
6640dc0b
...
...
@@ -99,5 +99,5 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
# Input quantize
x_q
,
_
=
custom_ops
.
scaled_int8_quant
(
x
,
act_scale
)
return
custom_ops
.
cutlass_scaled_mm
_dq
(
x_q
,
weight
.
t
(),
act_scale
,
weight_scale
,
x
.
dtype
)
return
custom_ops
.
cutlass_scaled_mm
(
x_q
,
weight
.
t
(),
act_scale
,
weight_scale
,
x
.
dtype
)
vllm/model_executor/layers/quantization/fp8.py
View file @
6640dc0b
...
...
@@ -257,11 +257,13 @@ class Fp8LinearMethod(LinearMethodBase):
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
if
bias
is
None
and
self
.
cutlass_fp8_supported
:
# Temporarily disable CUTLASS kernels due to an illegal memory access
#if bias is None and self.cutlass_fp8_supported:
if
False
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
)
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
_dq
(
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
6640dc0b
...
...
@@ -28,6 +28,7 @@ import torch
import
torch.nn
as
nn
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.utils
import
is_tpu
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return
x
.
flatten
(
-
2
)
def
_apply_rotary_emb
(
x
:
torch
.
Tensor
,
freqs_cis
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
x_
=
torch
.
view_as_complex
(
torch
.
stack
(
torch
.
chunk
(
x
.
transpose
(
1
,
2
).
float
(),
2
,
dim
=-
1
),
dim
=-
1
))
x_out
=
torch
.
view_as_real
(
x_
*
freqs_cis
).
type_as
(
x
)
x_out
=
torch
.
cat
(
torch
.
chunk
(
x_out
,
2
,
dim
=-
1
),
dim
=-
2
)
x_out
=
x_out
.
reshape
(
x_out
.
shape
[
0
],
x_out
.
shape
[
1
],
x_out
.
shape
[
2
],
-
1
).
transpose
(
1
,
2
)
return
x_out
class
RotaryEmbedding
(
CustomOp
):
"""Original rotary positional embedding."""
...
...
@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp):
self
.
dtype
=
dtype
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
use_native2
=
is_tpu
()
and
is_neox_style
if
not
self
.
use_native2
:
cache
=
cache
.
to
(
dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
else
:
cos
,
sin
=
cache
.
chunk
(
2
,
dim
=-
1
)
freqs_cis
=
cos
+
1j
*
sin
self
.
register_buffer
(
"freqs_cis"
,
freqs_cis
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
...
...
@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp):
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward()."""
"""A PyTorch-native implementation equivalent to forward().
This method mimics the implementation of the custom CUDA kernel
used in `forward_cuda()`.
"""
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
...
...
@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp):
key
=
key
.
flatten
(
-
2
)
return
query
,
key
def
forward_native2
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Another PyTorch-native implementation of forward().
This method might perform better than `forward_native()` when compiled.
"""
if
positions
.
dim
()
==
1
:
batch_size
=
1
seq_len
=
positions
.
shape
[
0
]
else
:
batch_size
,
seq_len
=
positions
.
shape
if
offsets
is
not
None
:
positions
=
positions
+
offsets
freqs_cis
=
self
.
freqs_cis
.
index_select
(
0
,
positions
.
flatten
())
freqs_cis
=
freqs_cis
.
view
(
batch_size
,
1
,
seq_len
,
-
1
)
query_shape
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
_apply_rotary_emb
(
query_rot
,
freqs_cis
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
batch_size
,
seq_len
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
_apply_rotary_emb
(
key_rot
,
freqs_cis
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
def
forward_cuda
(
self
,
positions
:
torch
.
Tensor
,
...
...
@@ -161,6 +221,17 @@ class RotaryEmbedding(CustomOp):
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
def
forward_tpu
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
forward_fn
=
(
self
.
forward_native2
if
self
.
use_native2
else
self
.
forward_native
)
return
forward_fn
(
positions
,
query
,
key
,
offsets
)
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
...
...
vllm/model_executor/model_loader/loader.py
View file @
6640dc0b
...
...
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
is_vllm_tensorized
,
load_with_tensorizer
,
tensorizer_weights_iterator
)
serialize_vllm_model
,
tensorizer_weights_iterator
)
from
vllm.model_executor.model_loader.utils
import
(
get_model_architecture
,
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
...
...
@@ -34,6 +34,7 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.vlm_base
import
VisionLanguageModelBase
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_tpu
logger
=
init_logger
(
__name__
)
...
...
@@ -230,12 +231,26 @@ class DefaultModelLoader(BaseModelLoader):
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
return
np_cache_weights_iterator
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
hf_weights_files
)
if
use_safetensors
:
return
safetensors_weights_iterator
(
hf_weights_files
)
return
pt_weights_iterator
(
hf_weights_files
)
weights_iterator
=
np_cache_weights_iterator
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
hf_weights_files
)
elif
use_safetensors
:
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
)
if
is_tpu
():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import
torch_xla.core.xla_model
as
xm
def
_xla_weights_iterator
(
iterator
:
Generator
):
for
weights
in
iterator
:
yield
weights
xm
.
mark_step
()
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
return
weights_iterator
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
...
...
@@ -380,6 +395,12 @@ class TensorizerLoader(BaseModelLoader):
cache_config
:
CacheConfig
)
->
nn
.
Module
:
self
.
_verify_config
(
model_config
,
parallel_config
)
if
parallel_config
.
tensor_parallel_size
>
1
:
from
vllm.distributed
import
get_tensor_model_parallel_rank
self
.
tensorizer_config
.
tensorizer_uri
=
\
self
.
tensorizer_config
.
tensorizer_uri
\
%
get_tensor_model_parallel_rank
()
if
is_vllm_tensorized
(
self
.
tensorizer_config
):
return
self
.
_load_model_serialized
(
model_config
,
device_config
,
lora_config
,
...
...
@@ -390,6 +411,16 @@ class TensorizerLoader(BaseModelLoader):
vision_language_config
,
cache_config
)
@
staticmethod
def
save_model
(
model
:
torch
.
nn
.
Module
,
tensorizer_config
:
TensorizerConfig
,
)
->
None
:
serialize_vllm_model
(
model
=
model
,
tensorizer_config
=
tensorizer_config
,
)
class
ShardedStateLoader
(
BaseModelLoader
):
"""
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
6640dc0b
...
...
@@ -2,11 +2,11 @@ import argparse
import
dataclasses
import
io
import
os
import
re
import
time
import
typing
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
Generator
,
Optional
,
Tuple
,
Type
,
Union
from
typing
import
BinaryIO
,
Generator
,
Optional
,
Tuple
,
Type
,
Union
import
torch
from
torch
import
nn
...
...
@@ -14,6 +14,7 @@ from transformers import PretrainedConfig
import
vllm.envs
as
envs
from
vllm.config
import
ModelConfig
,
ParallelConfig
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
@@ -48,8 +49,7 @@ logger = init_logger(__name__)
@
dataclass
class
TensorizerConfig
:
tensorizer_uri
:
Union
[
io
.
BufferedIOBase
,
io
.
RawIOBase
,
typing
.
BinaryIO
,
str
,
bytes
,
os
.
PathLike
,
int
]
tensorizer_uri
:
str
vllm_tensorized
:
Optional
[
bool
]
=
False
verify_hash
:
Optional
[
bool
]
=
False
num_readers
:
Optional
[
int
]
=
None
...
...
@@ -60,6 +60,12 @@ class TensorizerConfig:
model_class
:
Optional
[
Type
[
torch
.
nn
.
Module
]]
=
None
hf_config
:
Optional
[
PretrainedConfig
]
=
None
dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
_is_sharded
:
bool
=
False
def
__post_init__
(
self
):
# check if the configuration is for a sharded vLLM model
self
.
_is_sharded
=
isinstance
(
self
.
tensorizer_uri
,
str
)
\
and
re
.
search
(
r
'%0\dd'
,
self
.
tensorizer_uri
)
is
not
None
def
_construct_tensorizer_args
(
self
)
->
"TensorizerArgs"
:
tensorizer_args
=
{
...
...
@@ -78,13 +84,12 @@ class TensorizerConfig:
self
,
parallel_config
:
"ParallelConfig"
,
)
->
None
:
if
(
parallel_config
.
tensor_parallel_size
>
1
and
self
.
tensorizer_uri
is
not
None
)
:
if
parallel_config
.
tensor_parallel_size
>
1
\
and
not
self
.
_is_sharded
:
raise
ValueError
(
"Loading to multiple GPUs is not currently supported with "
"vLLM-serialized models. Please set tensor_parallel_size=1."
" or use a non-vLLM-serialized model, such as a "
"serialized Hugging Face `PretrainedModel`."
)
"For a sharded model, tensorizer_uri should include a"
" string format template like '%04d' to be formatted"
" with the rank of the shard"
)
def
verify_with_model_config
(
self
,
model_config
:
"ModelConfig"
)
->
None
:
if
(
model_config
.
quantization
is
not
None
...
...
@@ -102,8 +107,8 @@ def load_with_tensorizer(tensorizer_config: TensorizerConfig,
@
dataclass
class
TensorizerArgs
:
tensorizer_uri
:
Union
[
io
.
BufferedIOBase
,
io
.
RawIOBase
,
typing
.
BinaryIO
,
str
,
bytes
,
os
.
PathLike
,
int
]
tensorizer_uri
:
Union
[
io
.
BufferedIOBase
,
io
.
RawIOBase
,
BinaryIO
,
str
,
bytes
,
os
.
PathLike
,
int
]
vllm_tensorized
:
Optional
[
bool
]
=
False
verify_hash
:
Optional
[
bool
]
=
False
num_readers
:
Optional
[
int
]
=
None
...
...
@@ -332,6 +337,7 @@ class TensorizerAgent:
)
as
stream
,
TensorDeserializer
(
stream
,
dtype
=
self
.
tensorizer_config
.
dtype
,
device
=
f
'cuda:
{
torch
.
cuda
.
current_device
()
}
'
,
**
self
.
tensorizer_args
.
deserializer_params
)
as
deserializer
:
deserializer
.
load_into_module
(
self
.
model
)
end
=
time
.
perf_counter
()
...
...
@@ -400,33 +406,70 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
return
False
def
get_pretensorized_vllm_model
(
engine
:
"LLMEngine"
)
->
nn
.
Module
:
model
=
(
engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
def
serialize_vllm_model
(
model
:
nn
.
Module
,
tensorizer_config
:
TensorizerConfig
,
)
->
nn
.
Module
:
model
.
register_parameter
(
"vllm_tensorized_marker"
,
nn
.
Parameter
(
torch
.
tensor
((
1
,
),
device
=
"meta"
),
requires_grad
=
False
))
return
model
def
serialize_vllm_model
(
engine
:
"LLMEngine"
,
tensorizer_config
:
TensorizerConfig
,
encryption_key_path
:
Optional
[
str
]
=
None
)
\
->
nn
.
Module
:
model
=
get_pretensorized_vllm_model
(
engine
)
tensorizer_args
=
tensorizer_config
.
_construct_tensorizer_args
()
encryption_params
=
None
if
encryption_key_path
is
not
None
:
encryption_params
=
EncryptionParams
.
random
()
with
_write_stream
(
encryption_key_path
,
**
tensorizer_args
.
stream_params
)
as
stream
:
stream
.
write
(
encryption_params
.
key
)
if
(
keyfile
:
=
tensorizer_config
.
encryption_keyfile
)
is
not
None
:
with
open
(
keyfile
,
"rb"
)
as
f
:
key
=
f
.
read
()
encryption_params
=
EncryptionParams
(
key
=
key
)
with
_write_stream
(
tensorizer_args
.
tensorizer_uri
,
**
tensorizer_args
.
stream_params
)
as
stream
:
output_file
=
tensorizer_args
.
tensorizer_uri
if
tensorizer_config
.
_is_sharded
:
from
vllm.distributed
import
get_tensor_model_parallel_rank
output_file
=
output_file
%
get_tensor_model_parallel_rank
()
with
_write_stream
(
output_file
,
**
tensorizer_args
.
stream_params
)
as
stream
:
serializer
=
TensorSerializer
(
stream
,
encryption
=
encryption_params
)
serializer
.
write_module
(
model
)
serializer
.
close
()
logger
.
info
(
"Successfully serialized model to %s"
,
str
(
tensorizer_args
.
tensorizer_uri
))
logger
.
info
(
"Successfully serialized model to %s"
,
str
(
output_file
))
return
model
def
tensorize_vllm_model
(
engine_args
:
EngineArgs
,
tensorizer_config
:
TensorizerConfig
,
generate_keyfile
:
bool
=
True
):
"""Utility to load a model and then serialize it with Tensorizer
Intended to be used separately from running a vLLM server since it
creates its own Engine instance.
"""
engine_config
=
engine_args
.
create_engine_config
()
tensorizer_config
.
verify_with_model_config
(
engine_config
.
model_config
)
tensorizer_config
.
verify_with_parallel_config
(
engine_config
.
parallel_config
)
# generate the encryption key before creating the engine to support sharding
if
generate_keyfile
and
(
keyfile
:
=
tensorizer_config
.
encryption_keyfile
)
is
not
None
:
encryption_params
=
EncryptionParams
.
random
()
with
_write_stream
(
keyfile
,
s3_access_key_id
=
tensorizer_config
.
s3_access_key_id
,
s3_secret_access_key
=
tensorizer_config
.
s3_secret_access_key
,
s3_endpoint
=
tensorizer_config
.
s3_endpoint
,
)
as
stream
:
stream
.
write
(
encryption_params
.
key
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
if
tensorizer_config
.
_is_sharded
:
# if the engine is a distributed engine (for tensor parallel) then each
# worker shard needs to serialize its part of the model.
engine
.
model_executor
.
_run_workers
(
"save_tensorized_model"
,
tensorizer_config
=
tensorizer_config
,
)
else
:
# with a single worker, we can get to the underlying model directly
serialize_vllm_model
(
engine
.
model_executor
.
driver_worker
.
model_runner
.
model
,
tensorizer_config
,
)
vllm/model_executor/models/llava.py
View file @
6640dc0b
...
...
@@ -227,7 +227,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
object
,
)
->
SamplerOutput
:
"""Run forward pass for L
lava
1.5.
"""Run forward pass for L
LaVA-
1.5.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
...
...
@@ -247,22 +247,25 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
The model takes two types of image inputs:
PIXEL_VALUES and IMAGE_FEATURES.
The following shows how each maps to huggingface implementation.
PIXEL_VALUES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
IMAGE_FEATURES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
before going through the multi modal projector.
This model has two modes of image inputs:
`PIXEL_VALUES` and `IMAGE_FEATURES`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: For PIXEL_VALUES, expects a batch with shape
[1, 3, 336, 336].
image_features: For IMAGE_FEATURES, expects a batch with shape
[1, 576, 1024].
pixel_values: The pixels in each input image.
Expects a batch with shape `[1, 3, 336, 336]`.
(Only applicable to `PIXEL_VALUES` mode)
image_features: The image features for each input image outputted by
the vision tower before passing to the multi-modal projector.
Expects a batch with shape `[1, 576, 1024]`.
(Only applicable to `IMAGE_FEATURES` mode)
See also:
Each input maps to huggingface implementation, as follows:
- `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L360
- `image_features`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L437
"""
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment