Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99b471c2
Commit
99b471c2
authored
May 21, 2024
by
zhuwenwen
Browse files
merge v0.4.1
parents
1925d2e9
468d761b
Changes
336
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1654 additions
and
184 deletions
+1654
-184
vllm/model_executor/guided_decoding/outlines_decoding.py
vllm/model_executor/guided_decoding/outlines_decoding.py
+6
-7
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+61
-48
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+3
-4
vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json
...configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
...configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json
...configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json
...configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
...=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
+146
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+76
-19
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+1
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+95
-48
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+11
-3
vllm/model_executor/layers/ops/sample.py
vllm/model_executor/layers/ops/sample.py
+2
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+8
-3
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+373
-0
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+20
-15
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+139
-0
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+29
-23
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+16
-12
vllm/model_executor/layers/quantization/schema.py
vllm/model_executor/layers/quantization/schema.py
+84
-0
No files found.
vllm/model_executor/guided_decoding.py
→
vllm/model_executor/guided_decoding
/outlines_decoding
.py
View file @
99b471c2
...
@@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase
...
@@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
CompletionRequest
)
from
vllm.model_executor.guided_logits_processors
import
(
CFGLogitsProcessor
,
from
vllm.model_executor.guided_decoding.outlines_logits_processors
import
(
JSONLogitsProcessor
,
CFGLogitsProcessor
,
JSONLogitsProcessor
,
RegexLogitsProcessor
)
RegexLogitsProcessor
)
class
GuidedDecodingMode
(
Enum
):
class
GuidedDecodingMode
(
Enum
):
...
@@ -54,9 +53,9 @@ pair : UNESCAPED_STRING ":" value
...
@@ -54,9 +53,9 @@ pair : UNESCAPED_STRING ":" value
global_thread_pool
=
None
# used for generating logits processor fsm
global_thread_pool
=
None
# used for generating logits processor fsm
async
def
get_guided_decoding_logits_processor
(
async
def
get_
outlines_
guided_decoding_logits_processor
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
tokenizer
)
->
Union
[
JSONLogitsProcessor
,
RegexLogitsProcessor
]:
tokenizer
)
->
Union
[
JSONLogitsProcessor
,
RegexLogitsProcessor
,
None
]:
"""
"""
Given an OpenAI-compatible request, check for guided decoding parameters
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
and get the necessary logits processor for the given guide.
...
@@ -85,13 +84,13 @@ async def get_guided_decoding_logits_processor(
...
@@ -85,13 +84,13 @@ async def get_guided_decoding_logits_processor(
def
_get_guide_and_mode
(
def
_get_guide_and_mode
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]
)
->
Tuple
[
str
,
GuidedDecodingMode
]:
)
->
Union
[
Tuple
[
str
,
GuidedDecodingMode
]
,
Tuple
[
None
,
None
]]
:
if
request
.
guided_json
:
if
request
.
guided_json
:
json
=
request
.
guided_json
json
=
request
.
guided_json
if
isinstance
(
json
,
dict
):
if
isinstance
(
json
,
dict
):
# turn dict into hashable string
# turn dict into hashable string
json
=
json_dumps
(
json
,
sort_keys
=
True
)
json
=
json_dumps
(
json
)
elif
isinstance
(
json
,
BaseModel
):
elif
isinstance
(
json
,
BaseModel
):
# use pydantic signature so that different model classes
# use pydantic signature so that different model classes
# with the same fields will get hashed the same
# with the same fields will get hashed the same
...
...
vllm/model_executor/guided_logits_processors.py
→
vllm/model_executor/guided_
decoding/outlines_
logits_processors.py
View file @
99b471c2
...
@@ -13,13 +13,15 @@
...
@@ -13,13 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
copy
import
json
import
json
import
math
import
math
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
lru_cache
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Optional
,
Union
import
torch
import
torch
from
outlines.fsm.fsm
import
CFGFSM
,
RegexFSM
from
outlines.fsm.fsm
import
CFGFSM
,
FSM
,
RegexFSM
from
outlines.fsm.json_schema
import
build_regex_from_schema
from
outlines.fsm.json_schema
import
build_regex_from_schema
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
...
@@ -27,49 +29,9 @@ from transformers import PreTrainedTokenizerBase
...
@@ -27,49 +29,9 @@ from transformers import PreTrainedTokenizerBase
class
BaseLogitsProcessor
:
class
BaseLogitsProcessor
:
def
adapt_tokenizer
(
self
,
tokenizer
:
PreTrainedTokenizerBase
):
def
__init__
(
self
):
"""Adapt vLLM's tokenizer to use to compile the FSM.
# Child class should use initialize in their init.
self
.
fsm
:
FSM
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
"""
if
getattr
(
tokenizer
,
"_outlines_adapted"
,
False
):
return
tokenizer
tokenizer
.
vocabulary
=
tokenizer
.
get_vocab
()
tokenizer
.
special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
def
convert_token_to_string
(
token
:
str
)
->
str
:
from
transformers.file_utils
import
SPIECE_UNDERLINE
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
# A hack to handle missing spaces to HF's Llama tokenizers
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
return
" "
+
string
return
string
def
change_decoder
(
decoder
:
Callable
[[
List
[
int
]],
str
]
)
->
Callable
[[
List
[
int
]],
List
[
str
]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def
new_decoder
(
inp_tokens
:
List
[
int
])
->
List
[
str
]:
return
[
decoder
(
inp_tokens
)]
return
new_decoder
tokenizer
.
convert_token_to_string
=
convert_token_to_string
tokenizer
.
decode
=
change_decoder
(
tokenizer
.
decode
)
setattr
(
tokenizer
,
"_outlines_adapted"
,
True
)
# noqa: B010
return
tokenizer
def
init_state
(
self
):
def
init_state
(
self
):
"""Initialize the FSM states."""
"""Initialize the FSM states."""
...
@@ -78,7 +40,6 @@ class BaseLogitsProcessor:
...
@@ -78,7 +40,6 @@ class BaseLogitsProcessor:
def
__call__
(
self
,
input_ids
:
List
[
int
],
def
__call__
(
self
,
input_ids
:
List
[
int
],
scores
:
torch
.
Tensor
)
->
torch
.
Tensor
:
scores
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Use the FSM to bias the logits before sampling the next token."""
"""Use the FSM to bias the logits before sampling the next token."""
seq_id
=
hash
(
tuple
(
input_ids
))
seq_id
=
hash
(
tuple
(
input_ids
))
if
len
(
input_ids
)
==
0
:
if
len
(
input_ids
)
==
0
:
...
@@ -96,7 +57,6 @@ class BaseLogitsProcessor:
...
@@ -96,7 +57,6 @@ class BaseLogitsProcessor:
device
=
scores
.
device
)
device
=
scores
.
device
)
mask
[
allowed_tokens
]
=
0
mask
[
allowed_tokens
]
=
0
scores
.
add_
(
mask
)
scores
.
add_
(
mask
)
return
scores
return
scores
...
@@ -113,7 +73,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
...
@@ -113,7 +73,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer
The model's tokenizer
"""
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
tokenizer
=
_
adapt_tokenizer
(
tokenizer
)
fsm
=
RegexFSM
(
regex_string
,
tokenizer
)
fsm
=
RegexFSM
(
regex_string
,
tokenizer
)
self
.
fsm
=
fsm
self
.
fsm
=
fsm
...
@@ -167,6 +127,59 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
...
@@ -167,6 +127,59 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer
The model's tokenizer
"""
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
tokenizer
=
_
adapt_tokenizer
(
tokenizer
)
fsm
=
CFGFSM
(
cfg
,
tokenizer
)
fsm
=
CFGFSM
(
cfg
,
tokenizer
)
self
.
fsm
=
fsm
self
.
fsm
=
fsm
def
init_state
(
self
):
"""Initialize state with a CFGFSM copy."""
super
().
init_state
()
self
.
fsm
=
self
.
fsm
.
copy
()
@
lru_cache
def
_adapt_tokenizer
(
tokenizer
:
PreTrainedTokenizerBase
):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
"""
if
getattr
(
tokenizer
,
"_outlines_adapted"
,
False
):
return
tokenizer
tokenizer
=
copy
.
deepcopy
(
tokenizer
)
tokenizer
.
vocabulary
=
tokenizer
.
get_vocab
()
tokenizer
.
special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
def
convert_token_to_string
(
token
:
str
)
->
str
:
from
transformers.file_utils
import
SPIECE_UNDERLINE
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
# A hack to handle missing spaces to HF's Llama tokenizers
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
return
" "
+
string
return
string
def
change_decoder
(
decoder
:
Callable
[[
List
[
int
]],
str
])
->
Callable
[[
List
[
int
]],
List
[
str
]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def
new_decoder
(
inp_tokens
:
List
[
int
])
->
List
[
str
]:
return
[
decoder
(
inp_tokens
)]
return
new_decoder
tokenizer
.
convert_token_to_string
=
convert_token_to_string
tokenizer
.
decode
=
change_decoder
(
tokenizer
.
decode
)
setattr
(
tokenizer
,
"_outlines_adapted"
,
True
)
# noqa: B010
return
tokenizer
vllm/model_executor/layers/activation.py
View file @
99b471c2
...
@@ -6,11 +6,10 @@ import torch
...
@@ -6,11 +6,10 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
vllm._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.parallel_utils.utils
import
divide
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
...
vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json
0 → 100644
View file @
99b471c2
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"128"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
0 → 100644
View file @
99b471c2
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"48"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json
0 → 100644
View file @
99b471c2
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"48"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"128"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json
0 → 100644
View file @
99b471c2
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"48"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
0 → 100644
View file @
99b471c2
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"4"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"16"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"32"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"48"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"96"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"512"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
99b471c2
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
...
@@ -21,6 +21,8 @@ def fused_moe_kernel(
...
@@ -21,6 +21,8 @@ def fused_moe_kernel(
a_ptr
,
a_ptr
,
b_ptr
,
b_ptr
,
c_ptr
,
c_ptr
,
a_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
expert_ids_ptr
,
...
@@ -49,6 +51,7 @@ def fused_moe_kernel(
...
@@ -49,6 +51,7 @@ def fused_moe_kernel(
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8
:
tl
.
constexpr
,
):
):
"""
"""
Implements the fused computation for a Mixture of Experts (MOE) using
Implements the fused computation for a Mixture of Experts (MOE) using
...
@@ -111,6 +114,10 @@ def fused_moe_kernel(
...
@@ -111,6 +114,10 @@ def fused_moe_kernel(
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
(
offs_k
[:,
None
]
*
stride_bk
+
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
offs_bn
[
None
,
:]
*
stride_bn
)
if
use_fp8
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
# -----------------------------------------------------------
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
...
@@ -129,7 +136,10 @@ def fused_moe_kernel(
...
@@ -129,7 +136,10 @@ def fused_moe_kernel(
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
other
=
0.0
)
# We accumulate along the K dimension.
# We accumulate along the K dimension.
accumulator
+=
tl
.
dot
(
a
,
b
)
if
use_fp8
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
...
@@ -140,7 +150,10 @@ def fused_moe_kernel(
...
@@ -140,7 +150,10 @@ def fused_moe_kernel(
other
=
0
)
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
.
to
(
compute_type
)
if
use_fp8
:
accumulator
=
(
accumulator
*
a_scale
*
b_scale
).
to
(
compute_type
)
else
:
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
# -----------------------------------------------------------
# Write back the block of the output
# Write back the block of the output
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
...
@@ -207,15 +220,24 @@ def moe_align_block_size(
...
@@ -207,15 +220,24 @@ def moe_align_block_size(
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
B_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
mul_routed_weight
:
bool
,
top_k
:
int
,
mul_routed_weight
:
bool
,
top_k
:
int
,
config
:
Dict
[
str
,
Any
])
->
None
:
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
use_fp8
:
bool
)
->
None
:
assert
topk_weights
.
stride
(
1
)
==
1
assert
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
not
use_fp8
:
A_scale
=
None
assert
B_scale
is
None
else
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
)
assert
B_scale
is
not
None
grid
=
lambda
META
:
(
triton
.
cdiv
(
sorted_token_ids
.
shape
[
0
],
META
[
grid
=
lambda
META
:
(
triton
.
cdiv
(
sorted_token_ids
.
shape
[
0
],
META
[
'BLOCK_SIZE_M'
])
*
triton
.
cdiv
(
B
.
shape
[
1
],
META
[
'BLOCK_SIZE_N'
]),
)
'BLOCK_SIZE_M'
])
*
triton
.
cdiv
(
B
.
shape
[
1
],
META
[
'BLOCK_SIZE_N'
]),
)
...
@@ -223,6 +245,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
...
@@ -223,6 +245,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A
,
A
,
B
,
B
,
C
,
C
,
A_scale
,
B_scale
,
topk_weights
,
topk_weights
,
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
...
@@ -240,18 +264,21 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
...
@@ -240,18 +264,21 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
C
.
stride
(
2
),
C
.
stride
(
2
),
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
top_k
=
top_k
,
compute_type
=
tl
.
bfloat16
if
A
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
,
**
config
,
**
config
,
)
)
def
get_config_file_name
(
E
:
int
,
N
:
int
)
->
str
:
def
get_config_file_name
(
E
:
int
,
N
:
int
,
dtype
:
Optional
[
str
]
)
->
str
:
device_name
=
torch
.
cuda
.
get_device_name
().
replace
(
" "
,
"_"
)
device_name
=
torch
.
cuda
.
get_device_name
().
replace
(
" "
,
"_"
)
return
f
"E=
{
E
}
,N=
{
N
}
,device_name=
{
device_name
}
.json"
dtype_selector
=
""
if
not
dtype
else
f
",dtype=
{
dtype
}
"
return
f
"E=
{
E
}
,N=
{
N
}
,device_name=
{
device_name
}{
dtype_selector
}
.json"
@
functools
.
lru_cache
@
functools
.
lru_cache
def
get_moe_configs
(
E
:
int
,
N
:
int
)
->
Optional
[
Dict
[
int
,
Any
]]:
def
get_moe_configs
(
E
:
int
,
N
:
int
,
dtype
:
Optional
[
str
])
->
Optional
[
Dict
[
int
,
Any
]]:
"""
"""
Return optimized configurations for the fused MoE kernel.
Return optimized configurations for the fused MoE kernel.
...
@@ -263,7 +290,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
...
@@ -263,7 +290,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
# First look up if an optimized configuration is available in the configs
# First look up if an optimized configuration is available in the configs
# directory
# directory
json_file_name
=
get_config_file_name
(
E
,
N
)
json_file_name
=
get_config_file_name
(
E
,
N
,
dtype
)
config_file_path
=
os
.
path
.
join
(
config_file_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
...
@@ -288,6 +315,9 @@ def fused_moe(
...
@@ -288,6 +315,9 @@ def fused_moe(
renormalize
:
bool
,
renormalize
:
bool
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
This function computes a Mixture of Experts (MoE) layer using two sets of
...
@@ -305,6 +335,12 @@ def fused_moe(
...
@@ -305,6 +335,12 @@ def fused_moe(
Defaults to False.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
- torch.Tensor: The output tensor after applying the MoE layer.
...
@@ -358,7 +394,8 @@ def fused_moe(
...
@@ -358,7 +394,8 @@ def fused_moe(
config
=
override_config
config
=
override_config
else
:
else
:
# First try to load optimal config from the file
# First try to load optimal config from the file
configs
=
get_moe_configs
(
E
,
w2
.
shape
[
2
])
configs
=
get_moe_configs
(
E
,
w2
.
shape
[
2
],
"float8"
if
use_fp8
else
None
)
if
configs
:
if
configs
:
# If an optimal configuration map has been found, look up the
# If an optimal configuration map has been found, look up the
...
@@ -394,17 +431,37 @@ def fused_moe(
...
@@ -394,17 +431,37 @@ def fused_moe(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
)
topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
)
invoke_fused_moe_kernel
(
hidden_states
,
w1
,
intermediate_cache1
,
invoke_fused_moe_kernel
(
hidden_states
,
topk_weights
,
topk_ids
,
sorted_token_ids
,
w1
,
expert_ids
,
num_tokens_post_padded
,
False
,
intermediate_cache1
,
topk_ids
.
shape
[
1
],
config
)
w1_scale
,
topk_weights
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
topk_ids
.
shape
[
1
],
config
,
compute_type
=
tl
.
float16
,
use_fp8
=
use_fp8
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
intermediate_cache3
,
invoke_fused_moe_kernel
(
intermediate_cache2
,
topk_weights
,
topk_ids
,
sorted_token_ids
,
w2
,
expert_ids
,
num_tokens_post_padded
,
True
,
1
,
intermediate_cache3
,
config
)
w2_scale
,
topk_weights
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
True
,
1
,
config
,
compute_type
=
tl
.
float16
,
use_fp8
=
use_fp8
)
if
inplace
:
if
inplace
:
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
...
...
vllm/model_executor/layers/layernorm.py
View file @
99b471c2
...
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
...
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
class
RMSNorm
(
nn
.
Module
):
class
RMSNorm
(
nn
.
Module
):
...
...
vllm/model_executor/layers/linear.py
View file @
99b471c2
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
List
,
Optional
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.parallel_utils.utils
import
(
divide
,
split_tensor_along_last_dim
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -32,21 +32,43 @@ class LinearMethodBase(ABC):
...
@@ -32,21 +32,43 @@ class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods."""
"""Base class for different (maybe quantized) linear methods."""
@
abstractmethod
@
abstractmethod
def
create_weights
(
self
,
input_size_per_partition
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_size_per_partition
:
int
,
input_size
:
int
,
input_size_per_partition
:
int
,
output_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]:
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
"""Create weights for a linear layer."""
**
extra_weight_attrs
):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for QKVLinear
is a list contains the width of Wq, Wk, Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
apply_weights
(
self
,
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
torch
.
Tensor
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""Apply the weights to the input tensor."""
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise
NotImplementedError
raise
NotImplementedError
def
process_weights_after_loading
(
self
,
layer
:
nn
.
Module
)
->
None
:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class
UnquantizedLinearMethod
(
LinearMethodBase
):
class
UnquantizedLinearMethod
(
LinearMethodBase
):
"""Linear method without quantization.
"""Linear method without quantization.
...
@@ -60,22 +82,25 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -60,22 +82,25 @@ class UnquantizedLinearMethod(LinearMethodBase):
self
.
separate_bias_add
=
separate_bias_add
self
.
separate_bias_add
=
separate_bias_add
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
create_weights
(
self
,
input_size_per_partition
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_size_per_partition
:
int
,
input_size
:
int
,
input_size_per_partition
:
int
,
output_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]:
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
input_size_per_partition
,
dtype
=
params_dtype
),
dtype
=
params_dtype
),
requires_grad
=
False
)
requires_grad
=
False
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
return
{
"weight"
:
weight
}
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
apply_weights
(
self
,
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
torch
.
Tensor
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
weight
=
weights
[
"
weight
"
]
weight
=
layer
.
weight
if
self
.
separate_bias_add
:
if
self
.
separate_bias_add
:
if
bias
is
not
None
:
if
bias
is
not
None
:
return
F
.
linear
(
x
,
weight
)
+
bias
return
F
.
linear
(
x
,
weight
)
+
bias
...
@@ -124,12 +149,9 @@ class ReplicatedLinear(torch.nn.Module):
...
@@ -124,12 +149,9 @@ class ReplicatedLinear(torch.nn.Module):
if
linear_method
is
None
:
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
linear_weights
=
self
.
linear_method
.
create_weights
(
self
.
linear_method
.
create_weights
(
self
,
self
.
input_size
,
self
.
input_size
,
self
.
output_size
,
self
.
input_size
,
[
self
.
output_size
],
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
self
.
output_size
,
self
.
params_dtype
)
for
name
,
weight
in
self
.
linear_weights
.
items
():
if
isinstance
(
weight
,
torch
.
Tensor
):
self
.
register_parameter
(
name
,
weight
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size
,
dtype
=
self
.
params_dtype
))
torch
.
empty
(
self
.
output_size
,
dtype
=
self
.
params_dtype
))
...
@@ -139,7 +161,7 @@ class ReplicatedLinear(torch.nn.Module):
...
@@ -139,7 +161,7 @@ class ReplicatedLinear(torch.nn.Module):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
output
=
self
.
linear_method
.
apply_weights
(
self
.
linear_weights
,
x
,
bias
)
output
=
self
.
linear_method
.
apply_weights
(
self
,
x
,
bias
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
return
output
,
output_bias
...
@@ -162,6 +184,8 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -162,6 +184,8 @@ class ColumnParallelLinear(torch.nn.Module):
skip adding bias but instead return it.
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
linear_method: (Maybe quantized) linear method.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -173,6 +197,7 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -173,6 +197,7 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -189,14 +214,16 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -189,14 +214,16 @@ class ColumnParallelLinear(torch.nn.Module):
self
.
params_dtype
=
params_dtype
self
.
params_dtype
=
params_dtype
if
linear_method
is
None
:
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
linear_method
=
UnquantizedLinearMethod
()
if
output_sizes
is
None
:
output_sizes
=
[
output_size
]
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
linear_weights
=
self
.
linear_method
.
create_weights
(
self
.
linear_method
.
create_weights
(
self
,
self
.
input_size
,
self
.
output_size_per_partition
,
self
.
input_size
,
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
[
x
//
tp_size
for
x
in
output_sizes
],
for
name
,
weight
in
self
.
l
in
ear_weights
.
items
():
self
.
in
put_size
,
if
isinstance
(
weight
,
torch
.
Tensor
):
self
.
output_size
,
self
.
register_parameter
(
name
,
weight
)
self
.
params_dtype
,
set_weight_attrs
(
weight
,
{
"
weight_loader
"
:
self
.
weight_loader
}
)
weight_loader
=
self
.
weight_loader
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
torch
.
empty
(
self
.
output_size_per_partition
,
...
@@ -228,8 +255,7 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -228,8 +255,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
# Matrix multiply.
output_parallel
=
self
.
linear_method
.
apply_weights
(
output_parallel
=
self
.
linear_method
.
apply_weights
(
self
,
input_
,
bias
)
self
.
linear_weights
,
input_
,
bias
)
if
self
.
gather_output
:
if
self
.
gather_output
:
# All-gather across the partitions.
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
...
@@ -273,16 +299,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -273,16 +299,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
super
().
__init__
(
input_size
,
sum
(
output_sizes
),
bias
,
gather_output
,
super
().
__init__
(
input_size
,
sum
(
output_sizes
),
bias
,
gather_output
,
skip_bias_add
,
params_dtype
,
linear_method
)
skip_bias_add
,
params_dtype
,
linear_method
,
self
.
output_sizes
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
Optional
[
int
]
=
None
):
loaded_shard_id
:
Optional
[
int
]
=
None
):
param_data
=
param
.
data
param_data
=
param
.
data
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
:
# Loaded weight is already packed.
# Loaded weight is already packed.
if
output_dim
is
None
:
if
output_dim
is
None
:
...
@@ -339,6 +368,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -339,6 +368,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
shard_size
)
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
shard_size
=
loaded_weight
.
shape
[
0
]
shard_offset
=
loaded_shard_id
*
shard_size
param_data
=
param_data
.
narrow
(
0
,
shard_offset
,
shard_size
)
else
:
else
:
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
if
not
ignore_warning
:
if
not
ignore_warning
:
...
@@ -412,8 +446,14 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -412,8 +446,14 @@ class QKVParallelLinear(ColumnParallelLinear):
input_size
=
self
.
hidden_size
input_size
=
self
.
hidden_size
output_size
=
(
self
.
num_heads
+
output_size
=
(
self
.
num_heads
+
2
*
self
.
num_kv_heads
)
*
tp_size
*
self
.
head_size
2
*
self
.
num_kv_heads
)
*
tp_size
*
self
.
head_size
output_sizes
=
[
self
.
num_heads
*
tp_size
*
self
.
head_size
,
self
.
num_kv_heads
*
tp_size
*
self
.
head_size
,
self
.
num_kv_heads
*
tp_size
*
self
.
head_size
]
super
().
__init__
(
input_size
,
output_size
,
bias
,
False
,
skip_bias_add
,
super
().
__init__
(
input_size
,
output_size
,
bias
,
False
,
skip_bias_add
,
params_dtype
,
linear_method
)
params_dtype
,
linear_method
,
output_sizes
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
def
weight_loader
(
self
,
...
@@ -422,6 +462,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -422,6 +462,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id
:
Optional
[
str
]
=
None
):
loaded_shard_id
:
Optional
[
str
]
=
None
):
param_data
=
param
.
data
param_data
=
param
.
data
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
:
# Loaded weight is already packed.
# Loaded weight is already packed.
...
@@ -493,6 +534,12 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -493,6 +534,12 @@ class QKVParallelLinear(ColumnParallelLinear):
start_idx
=
shard_id
*
shard_size
start_idx
=
shard_id
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
shard_size
)
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
shard_size
=
loaded_weight
.
shape
[
0
]
shard_index
=
[
"q"
,
"k"
,
"v"
].
index
(
loaded_shard_id
)
param_data
=
param_data
.
narrow
(
0
,
shard_index
*
shard_size
,
shard_size
)
else
:
else
:
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
if
not
ignore_warning
:
if
not
ignore_warning
:
...
@@ -566,13 +613,13 @@ class RowParallelLinear(torch.nn.Module):
...
@@ -566,13 +613,13 @@ class RowParallelLinear(torch.nn.Module):
if
linear_method
is
None
:
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
linear_weights
=
self
.
linear_method
.
create_weights
(
self
.
linear_method
.
create_weights
(
self
,
self
.
input_size_per_partition
,
self
.
output_size
,
self
.
input_size
,
self
.
input_size_per_partition
,
self
.
output_size
,
self
.
params_dtype
)
[
self
.
output_size
],
for
name
,
weight
in
self
.
l
in
ear_weights
.
items
():
self
.
in
put_size
,
if
isinstance
(
weight
,
torch
.
Tensor
):
self
.
output_size
,
self
.
register_parameter
(
name
,
weight
)
self
.
params_dtype
,
set_weight_attrs
(
weight
,
{
"
weight_loader
"
:
self
.
weight_loader
}
)
weight_loader
=
self
.
weight_loader
)
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
raise
ValueError
(
"When not reduce the results, adding bias to the "
...
@@ -616,7 +663,7 @@ class RowParallelLinear(torch.nn.Module):
...
@@ -616,7 +663,7 @@ class RowParallelLinear(torch.nn.Module):
# Matrix multiply.
# Matrix multiply.
output_parallel
=
self
.
linear_method
.
apply_weights
(
output_parallel
=
self
.
linear_method
.
apply_weights
(
self
.
linear_weights
,
input_parallel
)
self
,
input_parallel
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
else
:
...
...
vllm/model_executor/layers/logits_processor.py
View file @
99b471c2
...
@@ -4,8 +4,7 @@ from typing import Optional
...
@@ -4,8 +4,7 @@ from typing import Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.model_executor.parallel_utils.communication_op
import
(
from
vllm.distributed
import
tensor_model_parallel_gather
tensor_model_parallel_gather
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -86,8 +85,16 @@ def _apply_logits_processors(
...
@@ -86,8 +85,16 @@ def _apply_logits_processors(
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
logits_row_idx
=
0
logits_row_idx
=
0
found_logits_processors
=
False
found_logits_processors
=
False
for
seq_ids
,
sampling_params
in
sampling_metadata
.
seq_groups
:
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
logits_processors
=
sampling_params
.
logits_processors
logits_processors
=
sampling_params
.
logits_processors
# handle prompt_logprobs by skipping rows in logits added for
# the prompt tokens (prompt logprobs are not processed)
if
(
i
<
sampling_metadata
.
num_prompts
and
sampling_params
.
prompt_logprobs
is
not
None
):
assert
len
(
seq_ids
)
==
1
logits_row_idx
+=
sampling_metadata
.
prompt_lens
[
i
]
-
1
if
logits_processors
:
if
logits_processors
:
found_logits_processors
=
True
found_logits_processors
=
True
for
seq_id
in
seq_ids
:
for
seq_id
in
seq_ids
:
...
@@ -100,5 +107,6 @@ def _apply_logits_processors(
...
@@ -100,5 +107,6 @@ def _apply_logits_processors(
else
:
else
:
logits_row_idx
+=
len
(
seq_ids
)
logits_row_idx
+=
len
(
seq_ids
)
if
found_logits_processors
:
if
found_logits_processors
:
# verifies that no rows in logits were missed unexpectedly
assert
logits_row_idx
==
logits
.
shape
[
0
]
assert
logits_row_idx
==
logits
.
shape
[
0
]
return
logits
return
logits
vllm/model_executor/layers/ops/sample.py
View file @
99b471c2
...
@@ -29,8 +29,8 @@ def _multi_split_sample(
...
@@ -29,8 +29,8 @@ def _multi_split_sample(
sampled_tokens_size
:
Tuple
[
int
,
int
],
sampled_tokens_size
:
Tuple
[
int
,
int
],
sampled_logprobs_size
:
Tuple
[
int
,
int
],
sampled_logprobs_size
:
Tuple
[
int
,
int
],
sample_indices
:
torch
.
Tensor
,
sample_indices
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
*
,
*
,
logprobs
:
Optional
[
torch
.
Tensor
]
=
None
,
modify_greedy_probs
:
bool
=
False
,
modify_greedy_probs
:
bool
=
False
,
save_logprobs
:
bool
=
False
,
save_logprobs
:
bool
=
False
,
):
):
...
@@ -167,6 +167,7 @@ def sample(
...
@@ -167,6 +167,7 @@ def sample(
sampled_logprobs_size
=
(
0
,
0
)
sampled_logprobs_size
=
(
0
,
0
)
logprobs
=
probs
logprobs
=
probs
assert
logprobs
is
not
None
if
_save_modified_probs
:
if
_save_modified_probs
:
sampled_modified_probs_size
=
sampled_tokens_size
sampled_modified_probs_size
=
sampled_tokens_size
else
:
else
:
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
99b471c2
from
typing
import
Type
from
typing
import
Type
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.fp8
import
FP8Config
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
_QUANTIZATION_CONFIG_REGISTRY
=
{
QUANTIZATION_METHODS
=
{
"aqlm"
:
AQLMConfig
,
"awq"
:
AWQConfig
,
"awq"
:
AWQConfig
,
"fp8"
:
FP8Config
,
"gptq"
:
GPTQConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"marlin"
:
MarlinConfig
,
"marlin"
:
MarlinConfig
,
...
@@ -16,12 +20,13 @@ _QUANTIZATION_CONFIG_REGISTRY = {
...
@@ -16,12 +20,13 @@ _QUANTIZATION_CONFIG_REGISTRY = {
def
get_quantization_config
(
quantization
:
str
)
->
Type
[
QuantizationConfig
]:
def
get_quantization_config
(
quantization
:
str
)
->
Type
[
QuantizationConfig
]:
if
quantization
not
in
_
QUANTIZATION_
CONFIG_REGISTRY
:
if
quantization
not
in
QUANTIZATION_
METHODS
:
raise
ValueError
(
f
"Invalid quantization method:
{
quantization
}
"
)
raise
ValueError
(
f
"Invalid quantization method:
{
quantization
}
"
)
return
_
QUANTIZATION_
CONFIG_REGISTRY
[
quantization
]
return
QUANTIZATION_
METHODS
[
quantization
]
__all__
=
[
__all__
=
[
"QuantizationConfig"
,
"QuantizationConfig"
,
"get_quantization_config"
,
"get_quantization_config"
,
"QUANTIZATION_METHODS"
,
]
]
vllm/model_executor/layers/quantization/aqlm.py
0 → 100644
View file @
99b471c2
# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
# and https://arxiv.org/pdf/2401.06118.pdf
import
math
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
from
vllm._C
import
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
def
get_int_dtype
(
nbits
:
int
)
->
torch
.
dtype
:
if
nbits
<=
8
:
return
torch
.
int8
if
nbits
<=
16
:
return
torch
.
int16
if
nbits
<=
32
:
return
torch
.
int32
if
nbits
<=
64
:
return
torch
.
int64
raise
ValueError
(
f
"No dtype available for
{
nbits
}
-bit codebooks"
)
@
torch
.
inference_mode
()
def
unpack_int_data
(
data
:
torch
.
IntTensor
,
nbits
:
int
)
->
torch
.
IntTensor
:
return
data
.
to
(
torch
.
int64
)
%
(
2
**
nbits
)
def
dequantize_weight
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
Decode float weights from quantization codes. Differentiable.
:param codes: tensor of integer quantization codes, shape
[*dims, num_out_groups, num_in_groups, num_codebooks]
:param codebooks: tensor of vectors for each quantization code,
[num_codebooks, codebook_size, out_group_size, in_group_size]
:param scales: weight will be multiplied by this factor, must be
broadcastble with
[*dims, out_groups, num_in_groups, out_group_size, in_group_size]
:return: reconstructed weight tensor of shape
[*dims, num_in_groups*group_size]
"""
num_out_groups
,
num_in_groups
,
num_codebooks
=
codes
.
shape
[
-
3
:]
num_codebooks
,
codebook_size
,
out_group_size
,
in_group_size
=
\
codebooks
.
shape
out_features
=
num_out_groups
*
out_group_size
in_features
=
num_in_groups
*
in_group_size
codebook_offsets
=
torch
.
arange
(
0
,
num_codebooks
*
codebook_size
,
codebook_size
,
device
=
codes
.
device
)
# shape: [num_codebooks]
reconstructed_weight_flat
=
F
.
embedding_bag
(
codes
.
flatten
(
0
,
-
2
)
+
codebook_offsets
,
codebooks
.
flatten
(
0
,
1
).
flatten
(
-
2
,
-
1
),
mode
=
"sum"
)
# [prod(dims) * num_out_groups * num_in_groups, out_group_size
# * in_group_size]
reconstructed_weight_groupwise
=
reconstructed_weight_flat
.
view
(
list
(
codes
.
shape
[:
-
3
])
+
[
num_out_groups
,
num_in_groups
,
out_group_size
,
in_group_size
])
if
scales
is
not
None
:
reconstructed_weight_groupwise
=
reconstructed_weight_groupwise
.
mul
(
scales
)
return
reconstructed_weight_groupwise
.
swapaxes
(
-
3
,
-
2
).
reshape
(
list
(
codes
.
shape
[:
-
3
])
+
[
out_features
,
in_features
])
def
dequantize_gemm
(
input
:
torch
.
Tensor
,
# [..., in_features]
codes
:
torch
.
IntTensor
,
# [num_out_groups, num_in_groups, num_codebooks]
codebooks
:
torch
.
Tensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
dequantized_weight
=
dequantize_weight
(
unpack_int_data
(
codes
,
codebooks
.
shape
[
1
].
bit_length
()
-
1
),
codebooks
,
scales
,
)
return
F
.
linear
(
input
,
dequantized_weight
,
bias
)
# Generic dequantization, slow but flexible.
def
generic_dequantize_gemm
(
input
:
torch
.
Tensor
,
# [..., in_features]
codes
:
torch
.
IntTensor
,
# [num_out_groups, num_in_groups, num_codebooks]
codebooks
:
torch
.
Tensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
output_partition_sizes
:
torch
.
IntTensor
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
output_shape
=
input
.
shape
[:
-
1
]
+
(
scales
.
shape
[
0
],
)
output
=
torch
.
empty
(
output_shape
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
num_outputs
=
len
(
output_partition_sizes
)
# break the inputs and codebooks apart then combine the outputs.
# Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
# multiply at the end.
num_codebooks
=
codebooks
.
shape
[
0
]
//
num_outputs
assert
(
scales
.
shape
[
0
]
==
codes
.
shape
[
0
])
assert
(
sum
(
output_partition_sizes
)
==
scales
.
shape
[
0
])
output_offset
=
0
codebooks_offset
=
0
for
output_size
in
output_partition_sizes
:
shard_output
=
dequantize_gemm
(
input
,
codes
.
narrow
(
0
,
output_offset
,
output_size
),
codebooks
.
narrow
(
0
,
codebooks_offset
,
num_codebooks
),
scales
.
narrow
(
0
,
output_offset
,
output_size
),
None
if
bias
is
None
else
bias
.
narrow
(
0
,
output_offset
,
output_size
))
output_slice
=
output
.
narrow
(
-
1
,
output_offset
,
output_size
)
assert
(
output_slice
.
shape
==
shard_output
.
shape
)
output_slice
.
copy_
(
shard_output
)
output_offset
+=
output_size
codebooks_offset
+=
num_codebooks
return
output
# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8
# at 6 and 9 times faster than the generic version above, respectively.
def
optimized_dequantize_gemm
(
input
:
torch
.
Tensor
,
# [..., in_features]
codes
:
torch
.
IntTensor
,
# [num_out_groups, num_in_groups, num_codebooks]
codebooks
:
torch
.
Tensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
output_partition_sizes
:
torch
.
IntTensor
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
weights
=
ops
.
aqlm_dequant
(
codes
,
codebooks
,
output_partition_sizes
)
if
bias
is
None
:
# scaling the output is fastest, so we do that when possible.
output
=
F
.
linear
(
input
,
weights
,
bias
)
orig_shape
=
output
.
shape
flattened_output
=
output
.
view
(
-
1
,
output
.
size
(
-
1
))
f_scales
=
scales
.
view
(
-
1
,
scales
.
shape
[
0
])
b_scales
=
f_scales
.
expand
(
flattened_output
.
shape
[
0
],
-
1
)
flattened_output
*=
b_scales
return
output
.
view
(
orig_shape
)
else
:
b_scales
=
scales
.
view
(
scales
.
shape
[:
-
3
]
+
(
-
1
,
)).
expand
(
-
1
,
weights
.
shape
[
1
])
weights
*=
b_scales
return
F
.
linear
(
input
,
weights
,
bias
)
class
AQLMConfig
(
QuantizationConfig
):
"""Config class for AQLM.
Reference: https://github.com/Vahe1994/AQLM
"""
def
__init__
(
self
,
in_group_size
:
int
,
nbits_per_codebook
:
int
,
num_codebooks
:
int
,
out_group_size
:
int
,
)
->
None
:
self
.
in_group_size
=
in_group_size
self
.
nbits_per_codebook
=
nbits_per_codebook
self
.
num_codebooks
=
num_codebooks
self
.
out_group_size
=
out_group_size
# out_group_size > 1 is untested, and probably won't work as-is.
assert
(
self
.
out_group_size
==
1
)
self
.
pack_factor
=
(
self
.
in_group_size
*
self
.
out_group_size
)
def
__repr__
(
self
)
->
str
:
return
(
f
"AQLMConfig(in_group_size=
{
self
.
in_group_size
}
, "
f
"nbits_per_codebook=
{
self
.
nbits_per_codebook
}
, "
f
"num_codebooks=
{
self
.
num_codebooks
}
, "
f
"out_group_size=
{
self
.
out_group_size
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"aqlm"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
70
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
# no extra configs.
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"AQLMConfig"
:
in_group_size
=
cls
.
get_from_keys
(
config
,
[
"in_group_size"
])
nbits_per_codebook
=
cls
.
get_from_keys
(
config
,
[
"nbits_per_codebook"
])
num_code_books
=
cls
.
get_from_keys
(
config
,
[
"num_codebooks"
])
out_group_size
=
cls
.
get_from_keys
(
config
,
[
"out_group_size"
])
return
cls
(
in_group_size
,
nbits_per_codebook
,
num_code_books
,
out_group_size
)
def
get_linear_method
(
self
)
->
"AQLMLinearMethod"
:
return
AQLMLinearMethod
(
self
)
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
AQLMLinearMethod
(
LinearMethodBase
):
"""Linear method for AQLM.
Args:
quant_config: The AQLM quantization config.
"""
def
__init__
(
self
,
quant_config
:
AQLMConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
del
output_size
# Unused.
del
input_size
# Unused.
if
params_dtype
!=
torch
.
half
:
raise
ValueError
(
"Only half is currently supported by aqlm"
)
if
input_size_per_partition
%
self
.
quant_config
.
in_group_size
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
output_size_per_partition
%
self
.
quant_config
.
out_group_size
!=
0
:
raise
ValueError
(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
codes
=
Parameter
(
torch
.
empty
(
# There could actually be two pack factors, one along input and
# one along output, but we don't currently support
# out_group_size, and only the one along output needs to be
# marked with "packed_dim" in order for QKVLinear to work.
output_size_per_partition
,
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
self
.
quant_config
.
num_codebooks
,
dtype
=
get_int_dtype
(
self
.
quant_config
.
nbits_per_codebook
),
),
requires_grad
=
False
,
)
set_weight_attrs
(
codes
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
},
)
codebooks
=
Parameter
(
torch
.
empty
(
self
.
quant_config
.
num_codebooks
*
len
(
output_partition_sizes
),
2
**
self
.
quant_config
.
nbits_per_codebook
,
self
.
quant_config
.
out_group_size
,
self
.
quant_config
.
in_group_size
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
codebooks
,
{
# metadata indicates fixed size concatenated along dim 0
"is_metadata"
:
True
,
"output_partition_sizes"
:
torch
.
tensor
(
output_partition_sizes
,
device
=
'cpu'
),
},
)
scales
=
Parameter
(
torch
.
empty
(
(
output_size_per_partition
//
self
.
quant_config
.
out_group_size
,
1
,
1
,
1
,
),
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"output_dim"
:
0
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
out_group_size
},
)
layer
.
register_parameter
(
"codes"
,
codes
)
set_weight_attrs
(
codes
,
extra_weight_attrs
)
layer
.
register_parameter
(
"codebooks"
,
codebooks
)
set_weight_attrs
(
codebooks
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
codebooks
=
layer
.
codebooks
codes
=
layer
.
codes
scales
=
layer
.
scales
output_partition_sizes
=
getattr
(
codebooks
,
"output_partition_sizes"
,
None
)
nbooks
=
codes
.
shape
[
2
]
ingroups
=
codebooks
.
shape
[
3
]
outgroups
=
codebooks
.
shape
[
2
]
bits
=
codebooks
.
shape
[
1
]
# We support these formats with dedicated gemm and decompression
# kernels.
if
ingroups
==
8
and
outgroups
==
1
and
(
(
bits
==
256
and
nbooks
==
2
)
or
(
bits
==
65536
and
nbooks
==
1
)):
# thresholds determined by timings on an A6000, one GPU
use_gemv
=
math
.
prod
(
x
.
shape
[:
-
1
])
<=
6
return
ops
.
aqlm_gemm
(
x
,
codes
,
codebooks
,
scales
,
output_partition_sizes
,
bias
,
)
if
use_gemv
else
optimized_dequantize_gemm
(
x
,
codes
,
codebooks
,
scales
,
output_partition_sizes
,
bias
,
)
# fall back all unoptimized formats
return
generic_dequantize_gemm
(
x
,
codes
,
codebooks
,
scales
,
output_partition_sizes
,
bias
,
)
vllm/model_executor/layers/quantization/awq.py
View file @
99b471c2
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
@@ -79,15 +79,18 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -79,15 +79,18 @@ class AWQLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
AWQConfig
):
def
__init__
(
self
,
quant_config
:
AWQConfig
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
input_size_per_partition
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_size_per_partition
:
int
,
input_size
:
int
,
input_size_per_partition
:
int
,
output_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]:
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
if
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
:
if
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
:
raise
ValueError
(
raise
ValueError
(
"The input size is not aligned with the quantized "
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
"tensor parallel size."
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
output_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
if
output_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
raise
ValueError
(
raise
ValueError
(
"The output size is not aligned with the quantized "
"The output size is not aligned with the quantized "
...
@@ -136,19 +139,21 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -136,19 +139,21 @@ class AWQLinearMethod(LinearMethodBase):
"input_dim"
:
0
,
"input_dim"
:
0
,
"output_dim"
:
1
,
"output_dim"
:
1
,
})
})
return
{
"qweight"
:
qweight
,
layer
.
register_parameter
(
"qweight"
,
qweight
)
"qzeros"
:
qzeros
,
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
"scales"
:
scales
,
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
}
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
def
apply_weights
(
self
,
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
Any
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
weights
[
"
qweight
"
]
qweight
=
layer
.
qweight
scales
=
weights
[
"
scales
"
]
scales
=
layer
.
scales
qzeros
=
weights
[
"
qzeros
"
]
qzeros
=
layer
.
qzeros
pack_factor
=
self
.
quant_config
.
pack_factor
pack_factor
=
self
.
quant_config
.
pack_factor
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
]
*
pack_factor
,
))
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
]
*
pack_factor
,
))
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
...
@@ -163,5 +168,5 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -163,5 +168,5 @@ class AWQLinearMethod(LinearMethodBase):
out
=
ops
.
awq_gemm
(
reshaped_x
,
qweight
,
scales
,
qzeros
,
out
=
ops
.
awq_gemm
(
reshaped_x
,
qweight
,
scales
,
qzeros
,
pack_factor
)
pack_factor
)
if
bias
is
not
None
:
if
bias
is
not
None
:
out
=
out
+
bias
out
.
add_
(
bias
)
return
out
.
reshape
(
out_shape
)
return
out
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/fp8.py
0 → 100644
View file @
99b471c2
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
class
FP8Config
(
QuantizationConfig
):
"""Config class for FP8."""
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"fp8"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
bfloat16
,
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# TODO: PyTorch 2.3.0+ is required to run FP8 on
# SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
# be included: https://github.com/pytorch/pytorch/pull/118881
return
90
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"FP8Config"
:
return
cls
()
def
get_linear_method
(
self
)
->
"Fp8LinearMethod"
:
return
Fp8LinearMethod
(
self
)
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
Fp8LinearMethod
(
LinearMethodBase
):
"""Linear method for FP8.
We now support common FP16/BF16 model checkpoints ONLY. The weight
scaling factor will be initialized after the model weights are loaded.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
Args:
quant_config: The quantization config.
"""
def
__init__
(
self
,
quant_config
:
FP8Config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
set_weight_attrs
(
weight
,
extra_weight_attrs
)
w_scale
=
Parameter
(
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"weight_scaling_factor"
,
w_scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Although the linear_method is propagated to all layers,
# only linear layers invoke "create_weights". So we check
# whether "weight_scaling_facor" is registered to determine
# whether the layer is a linear layer that requires quantization.
if
not
hasattr
(
layer
,
"weight_scaling_factor"
):
return
qweight
,
weight_scale
=
per_tensor_quantize
(
layer
.
weight
)
# torch._scaled_mm requires column-major in the second
# input (weight), so we transpose the quantized weight.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scaling_factor
.
data
.
copy_
(
weight_scale
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qinput
,
x_scale
=
per_tensor_quantize
(
x
)
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scaling_factor
,
bias
=
bias
,
)
return
output
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
float
]:
"""Quantize a tensor using per-tensor static scaling factor.
Args:
tensor: The input tensor.
"""
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
# Calculate the scale as dtype max divided by absmax.
# Since .abs() creates a new tensor, we use aminmax to get
# the min and max first and then calculate the absmax.
min_val
,
max_val
=
tensor
.
aminmax
()
amax
=
min_val
.
abs
().
max
(
max_val
.
abs
())
scale
=
finfo
.
max
/
amax
.
clamp
(
min
=
1e-12
)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight
=
(
tensor
*
scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight
=
qweight
.
to
(
torch
.
float8_e4m3fn
)
scale
=
scale
.
float
().
reciprocal
()
return
qweight
,
scale
vllm/model_executor/layers/quantization/gptq.py
View file @
99b471c2
...
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
...
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
@@ -89,18 +89,21 @@ class GPTQLinearMethod(LinearMethodBase):
...
@@ -89,18 +89,21 @@ class GPTQLinearMethod(LinearMethodBase):
def
create_weights
(
def
create_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
output_
size_per_partition
:
int
,
output_
partition_sizes
:
List
[
int
]
,
input_size
:
int
,
input_size
:
int
,
output_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
)
->
Dict
[
str
,
Any
]:
**
extra_weight_attrs
,
):
del
output_size
# Unused.
del
output_size
# Unused.
if
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
:
if
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
:
raise
ValueError
(
raise
ValueError
(
"The input size is not aligned with the quantized "
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
"tensor parallel size."
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
(
output_size_per_partition
%
self
.
quant_config
.
pack_factor
.
numerator
if
(
output_size_per_partition
%
self
.
quant_config
.
pack_factor
.
numerator
!=
0
):
!=
0
):
raise
ValueError
(
raise
ValueError
(
...
@@ -179,37 +182,40 @@ class GPTQLinearMethod(LinearMethodBase):
...
@@ -179,37 +182,40 @@ class GPTQLinearMethod(LinearMethodBase):
"input_dim"
:
scale_and_zero_input_dim
,
"input_dim"
:
scale_and_zero_input_dim
,
"output_dim"
:
1
,
"output_dim"
:
1
,
})
})
return
{
"qweight"
:
qweight
,
layer
.
register_parameter
(
"qweight"
,
qweight
)
"g_idx"
:
g_idx
,
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
"qzeros"
:
qzeros
,
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
"scales"
:
scales
,
set_weight_attrs
(
g_idx
,
extra_weight_attrs
)
"exllama_state"
:
exllama_state
,
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
}
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
exllama_state
=
exllama_state
def
apply_weights
(
self
,
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
Any
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
weights
[
"
qweight
"
]
qweight
=
layer
.
qweight
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
# exllama needs to shuffle the weight after the weight is loaded
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
# here we do the shuffle on first forward pass
if
weights
[
"
exllama_state
"
]
==
ExllamaState
.
UNINITIALIZED
:
if
layer
.
exllama_state
==
ExllamaState
.
UNINITIALIZED
:
if
self
.
quant_config
.
desc_act
:
if
self
.
quant_config
.
desc_act
:
weights
[
"g_idx"
]
=
torch
.
argsort
(
weights
[
"g_idx"
]).
to
(
layer
.
g_idx
.
data
=
torch
.
argsort
(
layer
.
g_idx
).
to
(
torch
.
int
)
torch
.
int
)
else
:
else
:
weights
[
"g_idx"
]
=
torch
.
empty
((
1
,
1
),
device
=
"meta"
)
layer
.
g_idx
.
data
=
torch
.
empty
((
0
,
),
weights
[
"exllama_state"
]
=
ExllamaState
.
READY
device
=
layer
.
g_idx
.
device
)
ops
.
gptq_shuffle
(
weights
[
"qweight"
],
weights
[
"g_idx"
],
layer
.
exllama_state
=
ExllamaState
.
READY
ops
.
gptq_shuffle
(
layer
.
qweight
,
layer
.
g_idx
,
self
.
quant_config
.
weight_bits
)
self
.
quant_config
.
weight_bits
)
output
=
ops
.
gptq_gemm
(
reshaped_x
,
weights
[
"qweight"
],
output
=
ops
.
gptq_gemm
(
reshaped_x
,
layer
.
qweight
,
layer
.
qzeros
,
weights
[
"qzeros"
],
weights
[
"scales"
],
layer
.
scales
,
layer
.
g_idx
,
weights
[
"g_idx"
],
layer
.
exllama_state
==
ExllamaState
.
READY
,
weights
[
"exllama_state"
]
==
ExllamaState
.
READY
,
self
.
quant_config
.
weight_bits
)
self
.
quant_config
.
weight_bits
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
=
output
+
bias
output
.
add_
(
bias
)
return
output
.
reshape
(
out_shape
)
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/marlin.py
View file @
99b471c2
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
@@ -91,12 +91,14 @@ class MarlinLinearMethod(LinearMethodBase):
...
@@ -91,12 +91,14 @@ class MarlinLinearMethod(LinearMethodBase):
def
create_weights
(
def
create_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
output_
size_per_partition
:
int
,
output_
partition_sizes
:
List
[
int
]
,
input_size
:
int
,
input_size
:
int
,
output_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
)
->
Dict
[
str
,
Any
]:
**
extra_weight_attrs
,
):
del
output_size
# Unused.
del
output_size
# Unused.
if
params_dtype
!=
torch
.
float16
:
if
params_dtype
!=
torch
.
float16
:
...
@@ -104,6 +106,7 @@ class MarlinLinearMethod(LinearMethodBase):
...
@@ -104,6 +106,7 @@ class MarlinLinearMethod(LinearMethodBase):
f
"The params dtype must be float16, but got
{
params_dtype
}
"
)
f
"The params dtype must be float16, but got
{
params_dtype
}
"
)
# Validate output_size_per_partition
# Validate output_size_per_partition
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
output_size_per_partition
%
self
.
quant_config
.
min_n_threads
!=
0
:
if
output_size_per_partition
%
self
.
quant_config
.
min_n_threads
!=
0
:
raise
ValueError
(
raise
ValueError
(
f
"Weight output_size_per_partition = "
f
"Weight output_size_per_partition = "
...
@@ -187,21 +190,22 @@ class MarlinLinearMethod(LinearMethodBase):
...
@@ -187,21 +190,22 @@ class MarlinLinearMethod(LinearMethodBase):
dtype
=
torch
.
int
),
dtype
=
torch
.
int
),
requires_grad
=
False
)
requires_grad
=
False
)
return
{
layer
.
register_parameter
(
"B"
,
qweight
)
"B"
:
qweight
,
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
"s"
:
scales
,
layer
.
register_parameter
(
"s"
,
scales
)
"workspace"
:
workspace
,
set_weight_attrs
(
scales
,
extra_weight_attrs
)
}
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
weights
:
Dict
[
str
,
Any
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qweight
=
weights
[
"B"
]
qweight
=
layer
.
B
scales
=
weights
[
"s"
]
scales
=
layer
.
s
workspace
=
weights
[
"
workspace
"
]
workspace
=
layer
.
workspace
x_2d
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
x_2d
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
...
...
vllm/model_executor/layers/quantization/schema.py
0 → 100644
View file @
99b471c2
"""
This file contains the Pydantic schemas for various quantization-related
parameters. When a relevant quantization technique is specified, these
parameters are loaded in the form of a JSON alongside the model weights
and augment the model with additional information needed for use of that
technique. The format of this JSON should be specified by one or more
schemas contained here.
For example, when the KV cache is quantized to FP8-E4M3 (currently only
possible on ROCm), the model can be optionally augmented with KV cache
scaling factors.
"""
from
typing
import
Dict
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
,
ValidationInfo
,
model_validator
class
KVCacheQuantSchema
(
BaseModel
):
dtype
:
str
# Each key is a TP rank. Each value is a dictionary mapping a TP rank's
# layer indices to their per-tensor KV cache scaling factor.
# TODO: Consider pulling this and its validation methods out into its
# own schema class (tricky as its members are variable)
scaling_factor
:
Dict
[
int
,
Dict
[
int
,
float
]]
@
model_validator
(
mode
=
"after"
)
def
check_is_fp8
(
self
)
->
"KVCacheQuantSchema"
:
assert
self
.
dtype
==
"float8_e4m3fn"
,
(
"Loaded scaling factors intended for KV cache dtype = "
f
"
{
self
.
dtype
}
rather than float8_e4m3fn!"
)
return
self
@
model_validator
(
mode
=
"after"
)
def
check_tp_ranks
(
self
,
info
:
ValidationInfo
)
->
"KVCacheQuantSchema"
:
context
=
info
.
context
if
context
:
tp_size
=
context
[
"tp_size"
]
num_hidden_layers
=
context
[
"num_hidden_layers"
]
assert
len
(
self
.
scaling_factor
)
==
tp_size
,
(
f
"Loaded dictionary has TP size
{
len
(
self
.
scaling_factor
)
}
"
f
"but LLM engine is currently running with TP size
{
tp_size
}
."
)
for
tp_rank
,
layer_maps
in
self
.
scaling_factor
.
items
():
assert
len
(
layer_maps
)
==
num_hidden_layers
,
(
f
"KV cache scales map for TP rank
{
tp_rank
}
is malformed. "
f
"Expected
{
num_hidden_layers
}
layers, got "
f
"
{
len
(
layer_maps
)
}
."
)
for
i
in
range
(
tp_size
):
assert
i
in
self
.
scaling_factor
,
(
f
"KV cache scales map for TP rank
{
i
}
not found."
)
return
self
@
model_validator
(
mode
=
"after"
)
def
check_current_rank
(
self
,
info
:
ValidationInfo
)
->
"KVCacheQuantSchema"
:
context
=
info
.
context
if
context
:
tp_rank
=
context
[
"tp_rank"
]
num_hidden_layers
=
context
[
"num_hidden_layers"
]
layer_scales_map
=
self
.
scaling_factor
[
tp_rank
]
for
i
in
range
(
num_hidden_layers
):
assert
i
in
layer_scales_map
,
(
f
"Could not find KV cache scales for layer
{
i
}
in "
f
"TP rank
{
tp_rank
}
."
)
return
self
class
QuantParamSchema
(
BaseModel
):
# TODO: Generalize and extend with more fields
# (e.g. weights/activations params) once functionality is enabled
model_config
=
ConfigDict
(
protected_namespaces
=
())
model_type
:
Optional
[
str
]
kv_cache
:
KVCacheQuantSchema
@
model_validator
(
mode
=
"after"
)
def
check_model_type
(
self
,
info
:
ValidationInfo
)
->
"QuantParamSchema"
:
context
=
info
.
context
if
context
:
model_type
=
context
.
get
(
"model_type"
,
None
)
if
model_type
is
not
None
:
assert
model_type
==
self
.
model_type
,
(
f
"Model type is
{
model_type
}
but loaded "
f
"scaling factors belonging to different "
f
"model type
{
self
.
model_type
}
!"
)
return
self
Prev
1
…
9
10
11
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment