Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
217ee621
Commit
217ee621
authored
Dec 05, 2024
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.6.2-dev' into v0.6.2-dev
parents
f0021a4d
3f78216a
Changes
68
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
187 additions
and
44 deletions
+187
-44
tests/kernels/untest_machete_gemm.py
tests/kernels/untest_machete_gemm.py
+0
-0
tests/kernels/untest_mamba_ssm.py
tests/kernels/untest_mamba_ssm.py
+0
-0
tests/kernels/untest_marlin_gemm.py
tests/kernels/untest_marlin_gemm.py
+0
-0
tests/kernels/untest_permute_cols.py
tests/kernels/untest_permute_cols.py
+24
-0
tests/kernels/utils.py
tests/kernels/utils.py
+20
-17
tests/models/decoder_only/vision_language/test_intern_vit.py
tests/models/decoder_only/vision_language/test_intern_vit.py
+3
-5
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+3
-2
tests/tensorizer_loader/test_tensorizer.py
tests/tensorizer_loader/test_tensorizer.py
+2
-2
tests/tokenization/test_tokenizer.py
tests/tokenization/test_tokenizer.py
+8
-2
tests/utils.py
tests/utils.py
+31
-1
vllm/__init__.py
vllm/__init__.py
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-2
vllm/attention/selector.py
vllm/attention/selector.py
+1
-1
vllm/benchmarks/benchmark_throughput.py
vllm/benchmarks/benchmark_throughput.py
+2
-2
vllm/config.py
vllm/config.py
+1
-1
vllm/distributed/device_communicators/pynccl_wrapper.py
vllm/distributed/device_communicators/pynccl_wrapper.py
+1
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-2
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+28
-3
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+2
-2
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+56
-0
No files found.
tests/kernels/test_machete_gemm.py
→
tests/kernels/
un
test_machete_gemm.py
View file @
217ee621
File moved
tests/kernels/test_mamba_ssm.py
→
tests/kernels/
un
test_mamba_ssm.py
View file @
217ee621
File moved
tests/kernels/test_marlin_gemm.py
→
tests/kernels/
un
test_marlin_gemm.py
View file @
217ee621
File moved
tests/kernels/test_permute_cols.py
→
tests/kernels/
un
test_permute_cols.py
View file @
217ee621
...
...
@@ -3,13 +3,22 @@ import torch
from
tests.kernels.utils
import
opcheck
from
vllm._custom_ops
import
permute_cols
from
.utils
import
torch_version
@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
    """Check that ``permute_cols`` matches plain column indexing.

    The test builds a random CUDA tensor of the given ``shape``/``dtype`` and a
    random column permutation, runs the custom op, and compares the result
    with ``x[:, perm]``.

    Only PyTorch 2.3 and 2.4 are handled. ``opcheck`` is run on 2.4 only; any
    other version is reported as skipped rather than silently passing.
    """
    if not torch_version.startswith(("2.3", "2.4")):
        # Surface unhandled versions in the test report instead of printing
        # a message and counting the test as passed.
        pytest.skip(
            f"PyTorch version {torch_version} is not specifically handled.")

    x = torch.randn(shape, dtype=dtype).cuda()
    perm = torch.randperm(x.shape[1]).to(torch.int).cuda()

    # The opcheck helper is only exercised on the 2.4 series.
    if torch_version.startswith("2.4"):
        opcheck(torch.ops._C.permute_cols, (x, perm))

    y = permute_cols(x, perm)
    # assert_close raises on mismatch; a bare torch.allclose(...) call only
    # returns a bool, which would be discarded and verify nothing.
    torch.testing.assert_close(y, x[:, perm])
tests/kernels/utils.py
View file @
217ee621
...
...
@@ -30,6 +30,8 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
"test_aot_dispatch_dynamic"
,
)
torch_version
=
torch
.
__version__
class
QKVInputs
(
NamedTuple
):
'''
...
...
@@ -974,20 +976,21 @@ def fp8_allclose(
equal_nan
=
equal_nan
)).
item
())
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: "Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, "
            "torch._library.custom_ops.CustomOpDef]",
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    """Run ``torch.library.opcheck`` on ``op`` with fp8-aware ``allclose``.

    The function is always defined so that ``from tests.kernels.utils import
    opcheck`` works on every PyTorch version. The ``op`` annotation is a
    string so that it is not evaluated at definition time on builds that do
    not provide ``torch._library.custom_ops``.

    Args:
        op: The operator (overload, overload packet or custom op def) to check.
        args: Positional arguments passed to the operator.
        kwargs: Optional keyword arguments passed to the operator.
        test_utils: Names of the opcheck test utilities to run.
        raise_exception: Forwarded to ``torch.library.opcheck``.
        cond: When false, the check is skipped and ``{}`` is returned.

    Returns:
        The dictionary returned by ``torch.library.opcheck``, or an empty
        dictionary when the check is skipped or unavailable.
    """
    if not cond:
        return {}

    if not hasattr(torch.library, "opcheck"):
        # This torch build has no public opcheck; warn instead of failing so
        # callers can still run their numerical comparisons.
        import warnings
        warnings.warn(
            f"torch.library.opcheck is not available in PyTorch "
            f"{torch.__version__}; skipping opcheck.",
            stacklevel=2,
        )
        return {}

    with unittest.mock.patch('torch.allclose', new=fp8_allclose):
        return torch.library.opcheck(op,
                                     args,
                                     kwargs,
                                     test_utils=test_utils,
                                     raise_exception=raise_exception)
tests/models/decoder_only/vision_language/test_intern_vit.py
View file @
217ee621
...
...
@@ -4,7 +4,7 @@ import os
import
pytest
import
torch
import
torch.nn
as
nn
from
huggingface_hub
import
snapshot_download
#
from huggingface_hub import snapshot_download
from
transformers
import
AutoConfig
,
AutoModel
,
CLIPImageProcessor
from
....conftest
import
_ImageAssets
,
cleanup
...
...
@@ -14,10 +14,8 @@ from ....utils import models_path_prefix
# dynamic_module and trust_remote_code for hf_runner
DOWNLOAD_PATTERN
=
[
"*.json"
,
"*.py"
,
"*.safetensors"
,
"*.txt"
,
"*.model"
]
models
=
[
snapshot_download
(
os
.
path
.
join
(
models_path_prefix
,
"OpenGVLab/InternViT-300M-448px"
),
allow_patterns
=
DOWNLOAD_PATTERN
),
snapshot_download
(
os
.
path
.
join
(
models_path_prefix
,
"OpenGVLab/InternViT-6B-448px-V1-5"
),
allow_patterns
=
DOWNLOAD_PATTERN
),
os
.
path
.
join
(
models_path_prefix
,
"OpenGVLab/InternViT-300M-448px"
),
os
.
path
.
join
(
models_path_prefix
,
"OpenGVLab/InternViT-6B-448px-V1-5"
),
]
...
...
tests/multimodal/test_utils.py
View file @
217ee621
...
...
@@ -5,12 +5,13 @@ from typing import Dict, Tuple
import
numpy
as
np
import
pytest
import
os
from
PIL
import
Image
from
transformers
import
AutoConfig
,
AutoTokenizer
from
vllm.multimodal.utils
import
(
async_fetch_image
,
fetch_image
,
repeat_and_pad_placeholder_tokens
)
from
..utils
import
urls_port
from
..utils
import
models_path_prefix
,
urls_port
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS
=
[
...
...
@@ -85,7 +86,7 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
assert
_image_equals
(
data_image_sync
,
data_image_async
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"llava-hf/llava-v1.6-mistral-7b-hf"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
os
.
path
.
join
(
models_path_prefix
,
"llava-hf/llava-v1.6-mistral-7b-hf"
)
])
def
test_repeat_and_pad_placeholder_tokens
(
model
):
config
=
AutoConfig
.
from_pretrained
(
model
)
image_token_id
=
config
.
image_token_index
...
...
tests/tensorizer_loader/test_tensorizer.py
View file @
217ee621
...
...
@@ -88,7 +88,7 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
def
test_can_deserialize_s3
(
vllm_runner
):
model_ref
=
os
.
path
.
join
(
models_path_prefix
,
"EleutherAI/pythia-1.4b"
)
tensorized_path
=
f
"
s3://tensorized/
{
model_ref
}
/fp16/model.tensors"
tensorized_path
=
f
"
{
model_ref
}
/fp16/model.tensors"
with
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
...
...
@@ -341,7 +341,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
def
test_tensorizer_with_tp_path_without_template
(
vllm_runner
):
with
pytest
.
raises
(
ValueError
):
model_ref
=
os
.
path
.
join
(
models_path_prefix
,
"EleutherAI/pythia-1.4b"
)
tensorized_path
=
f
"
s3://tensorized/
{
model_ref
}
/fp16/model.tensors"
tensorized_path
=
f
"
{
model_ref
}
/fp16/model.tensors"
vllm_runner
(
model_ref
,
...
...
tests/tokenization/test_tokenizer.py
View file @
217ee621
...
...
@@ -5,9 +5,15 @@ from transformers import PreTrainedTokenizerBase
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
..utils
import
models_path_prefix
# TOKENIZER_NAMES = [
# os.path.join(models_path_prefix, "facebook/opt-125m"),
# os.path.join(models_path_prefix, "gpt2"),
# ]
# export HF_ENDPOINT=https://hf-mirror.com
TOKENIZER_NAMES
=
[
os
.
path
.
join
(
models_path_prefix
,
"facebook/opt-125m"
)
,
os
.
path
.
join
(
models_path_prefix
,
"gpt2"
)
,
"facebook/opt-125m"
,
"gpt2"
,
]
...
...
tests/utils.py
View file @
217ee621
...
...
@@ -23,7 +23,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.model_executor.model_loader.loader
import
get_model_loader
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
FlexibleArgumentParser
,
cuda_device_count_stateless
,
from
vllm.utils
import
(
FlexibleArgumentParser
,
GB_bytes
,
cuda_device_count_stateless
,
get_open_port
,
is_hip
)
import
vllm.envs
as
envs
import
os
...
...
@@ -459,6 +459,36 @@ def fork_new_process_for_each_test(
return
wrapper
def large_gpu_test(*, min_gb: int):
    """
    Decorate a test to be skipped if no GPU is available or it does not have
    sufficient memory.

    Currently, the CI machine uses L4 GPU which has 24 GB VRAM.

    Args:
        min_gb: Minimum total device memory, in GB, required to run the test.

    Returns:
        A decorator that wraps the test with
        ``fork_new_process_for_each_test`` and applies a
        ``pytest.mark.skipif`` marker based on the detected memory.
    """
    try:
        if current_platform.is_cpu():
            memory_gb = 0
        else:
            memory_gb = current_platform.get_device_total_memory() / GB_bytes
    except Exception as e:
        # Best effort: if the memory query fails, treat the device as having
        # no memory so the test is skipped instead of erroring out here.
        warnings.warn(
            f"An error occurred when finding the available memory: {e}",
            stacklevel=2,
        )
        memory_gb = 0

    test_skipif = pytest.mark.skipif(
        memory_gb < min_gb,
        # Report the requirement (min_gb), not the amount that was detected.
        reason=f"Need at least {min_gb}GB GPU memory to run the test.",
    )

    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
        return test_skipif(fork_new_process_for_each_test(f))

    return wrapper
def
multi_gpu_test
(
*
,
num_gpus
:
int
):
"""
...
...
vllm/__init__.py
View file @
217ee621
...
...
@@ -11,7 +11,7 @@ from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput
,
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.version
import
__version__
,
__version_tuple__
,
__
d
cu_version__
from
vllm.version
import
__version__
,
__version_tuple__
,
__
h
cu_version__
__all__
=
[
...
...
vllm/_custom_ops.py
View file @
217ee621
...
...
@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
try
:
from
lmslim
import
quant_ops
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq model.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer gptq or awq
or w8a8
model.
\n
"
)
logger
=
init_logger
(
__name__
)
...
...
@@ -706,9 +706,9 @@ def cutlass_scaled_mm(a: torch.Tensor,
# torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
# return out
#return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
return
quant_ops
.
rocblas_scaled_mm_nn
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
def
rocblas_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
...
...
vllm/attention/selector.py
View file @
217ee621
...
...
@@ -207,7 +207,7 @@ def which_attn_to_use(
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
logger
.
info
(
"%s is not supported in
AMD GPU
s."
,
selected_backend
)
logger
.
info
(
"%s is not supported in
hcu
s."
,
selected_backend
)
return
_Backend
.
ROCM_FLASH
# FlashAttn in NVIDIA GPUs.
...
...
vllm/benchmarks/benchmark_throughput.py
View file @
217ee621
...
...
@@ -522,7 +522,7 @@ if __name__ == "__main__":
default
=
"auto"
,
help
=
'Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (
AMD GPU
) supports fp8 (=fp8_e4m3)'
)
'ROCm (
hcu
) supports fp8 (=fp8_e4m3)'
)
parser
.
add_argument
(
'--quantization-param-path'
,
type
=
str
,
...
...
@@ -531,7 +531,7 @@ if __name__ == "__main__":
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (
AMD GPU
), FP8_E4M3 is '
'cuda version greater than 11.8. On ROCm (
hcu
), FP8_E4M3 is '
'instead supported for common inference criteria.'
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
...
...
vllm/config.py
View file @
217ee621
...
...
@@ -936,7 +936,7 @@ class ParallelConfig:
self
.
disable_custom_all_reduce
=
True
logger
.
info
(
"Disabled the custom all-reduce kernel because it is not "
"supported on
AMD GPU
s."
)
"supported on
hcu
s."
)
if
self
.
ray_workers_use_nsight
and
not
self
.
use_ray
:
raise
ValueError
(
"Unable to use nsight profiling unless workers "
"run with Ray."
)
...
...
vllm/distributed/device_communicators/pynccl_wrapper.py
View file @
217ee621
...
...
@@ -195,7 +195,7 @@ class NCCLLibrary:
except
Exception
as
e
:
logger
.
error
(
"Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/
AMD GPU
s."
"It is expected if you are not running on NVIDIA/
hcu
s."
"Otherwise, the nccl library might not exist, be corrupted "
"or it does not support the current platform %s."
"If you already have the library, please set the "
...
...
vllm/engine/arg_utils.py
View file @
217ee621
...
...
@@ -294,7 +294,7 @@ class EngineArgs:
default
=
EngineArgs
.
kv_cache_dtype
,
help
=
'Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (
AMD GPU
) supports fp8 (=fp8_e4m3)'
)
'ROCm (
hcu
) supports fp8 (=fp8_e4m3)'
)
parser
.
add_argument
(
'--quantization-param-path'
,
type
=
nullable_str
,
...
...
@@ -304,7 +304,7 @@ class EngineArgs:
'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
'default to 1.0, which may cause accuracy issues. '
'FP8_E5M2 (without scaling) is only supported on cuda version'
'greater than 11.8. On ROCm (
AMD GPU
), FP8_E4M3 is instead '
'greater than 11.8. On ROCm (
hcu
), FP8_E4M3 is instead '
'supported for common inference criteria.'
)
parser
.
add_argument
(
'--max-model-len'
,
type
=
int
,
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
217ee621
...
...
@@ -4,12 +4,12 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
,
W8a8GetCacheJSON
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY
=
torch
.
ones
(
1
).
cuda
()
if
is_hip
()
else
None
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
def
cutlass_fp8_supported
()
->
bool
:
# cutlass is not supported on Rocm
...
...
@@ -200,12 +200,37 @@ def apply_int8_linear(
x_q
,
x_scale
,
_
=
ops
.
scaled_int8_quant
(
input
,
input_scale
)
if
w8a8_strategy
==
1
:
m
=
x_q
.
shape
[
0
]
k
=
x_q
.
shape
[
1
]
n
=
weight
.
shape
[
1
]
if
f
"
{
m
}
_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
[
0
]:
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
0
][
f
"
{
m
}
_
{
n
}
_
{
k
}
"
]
#print("json files:",best_config)
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
[
0
]:
if
m
<
64
:
m_
=
32
elif
m
<
128
:
m_
=
64
elif
m
<
256
:
m_
=
128
elif
m
<
512
:
m_
=
256
elif
m
<
1024
:
m_
=
512
else
:
m_
=
1024
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
0
][
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
else
:
best_config
=
None
print
(
"config not found!"
)
return
ops
.
triton_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
bias
=
bias
,
best_config
=
best_config
)
elif
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
weight
,
...
...
vllm/model_executor/model_loader/utils.py
View file @
217ee621
...
...
@@ -23,14 +23,14 @@ def get_model_architecture(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'Qwen2MoeForCausalLM'
,
'Qwen2VLForConditionalGeneration'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
]
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'Qwen2MoeForCausalLM'
,
'Qwen2VLForConditionalGeneration'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
if
architectures
==
[
'BloomForCausalLM'
]
or
os
.
getenv
(
'LM_NN'
)
==
'0'
:
if
(
architectures
==
[
'BloomForCausalLM'
]
or
architectures
==
[
'FalconForCausalLM'
])
or
os
.
getenv
(
'LM_NN'
)
==
'0'
:
os
.
environ
[
'LM_NN'
]
=
'0'
else
:
os
.
environ
[
'LM_NN'
]
=
'1'
...
...
vllm/model_executor/models/falcon.py
View file @
217ee621
...
...
@@ -21,6 +21,8 @@
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
os
import
re
import
torch
from
torch
import
nn
from
torch.nn
import
LayerNorm
...
...
@@ -47,6 +49,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
RWConfig
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
...
...
@@ -175,6 +180,11 @@ class FalconAttention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
self
,
...
...
@@ -184,6 +194,8 @@ class FalconAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
bias
=
self
.
query_key_value
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
if
bias
is
not
None
:
qkv
+=
bias
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
...
...
@@ -246,6 +258,9 @@ class FalconDecoderLayer(nn.Module):
self
.
mlp
=
FalconMLP
(
config
,
quant_config
)
self
.
config
=
config
if
(
not
hasattr
(
config
,
"num_ln_in_parallel_attn"
)):
config
.
num_ln_in_parallel_attn
=
None
if
(
config
.
num_ln_in_parallel_attn
is
None
and
config
.
new_decoder_architecture
):
config
.
num_ln_in_parallel_attn
=
2
...
...
@@ -403,6 +418,17 @@ class FalconForCausalLM(nn.Module):
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'0'
))
def
forward
(
self
,
...
...
@@ -481,3 +507,33 @@ class FalconForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attention.query_key_value.weight"
,
"self_attention.dense.weight"
,
"mlp.dense_h_to_4h.weight"
,
"mlp.dense_4h_to_h.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attention.query_key_value.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
\ No newline at end of file
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment