Commit 217ee621 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.6.2-dev' into v0.6.2-dev

parents f0021a4d 3f78216a
...@@ -3,13 +3,22 @@ import torch ...@@ -3,13 +3,22 @@ import torch
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm._custom_ops import permute_cols from vllm._custom_ops import permute_cols
from .utils import torch_version
@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
    """Check that ``permute_cols`` matches plain column indexing ``x[:, perm]``.

    The test is parametrized over input shape and dtype. It only runs on
    PyTorch 2.3.x and 2.4.x; on any other version it is reported as skipped
    rather than silently passing without exercising the kernel.

    On 2.4.x the registered op is additionally run through ``opcheck``;
    on 2.3.x only the numerical comparison is performed.
    """
    runs_opcheck = torch_version.startswith("2.4")
    if not (runs_opcheck or torch_version.startswith("2.3")):
        # Previously this case only printed a message, so the test was
        # reported as passed even though nothing had been verified.
        pytest.skip(
            f"PyTorch version {torch_version} is not specifically handled.")

    x = torch.randn(shape, dtype=dtype).cuda()
    perm = torch.randperm(x.shape[1]).to(torch.int).cuda()

    if runs_opcheck:
        opcheck(torch.ops._C.permute_cols, (x, perm))

    y = permute_cols(x, perm)
    # Use an asserting comparison on every supported version: a bare
    # ``torch.allclose(...)`` call returns a bool that was being discarded,
    # so the 2.3 path could never fail.
    torch.testing.assert_close(y, x[:, perm])
\ No newline at end of file
...@@ -30,6 +30,8 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = ( ...@@ -30,6 +30,8 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
"test_aot_dispatch_dynamic", "test_aot_dispatch_dynamic",
) )
torch_version = torch.__version__
class QKVInputs(NamedTuple): class QKVInputs(NamedTuple):
''' '''
...@@ -974,9 +976,10 @@ def fp8_allclose( ...@@ -974,9 +976,10 @@ def fp8_allclose(
equal_nan=equal_nan)).item()) equal_nan=equal_nan)).item())
# A special version of op check that has a restricted default set of test_utils if torch_version.startswith("2.4"):
# and a patched version of allclose that supports fp8 types. # A special version of op check that has a restricted default set of test_utils
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, # and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef], torch._library.custom_ops.CustomOpDef],
args: Tuple[Any, ...], args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None, kwargs: Optional[Dict[str, Any]] = None,
......
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from huggingface_hub import snapshot_download # from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from ....conftest import _ImageAssets, cleanup from ....conftest import _ImageAssets, cleanup
...@@ -14,10 +14,8 @@ from ....utils import models_path_prefix ...@@ -14,10 +14,8 @@ from ....utils import models_path_prefix
# dynamic_module and trust_remote_code for hf_runner # dynamic_module and trust_remote_code for hf_runner
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
models = [ models = [
snapshot_download(os.path.join(models_path_prefix, "OpenGVLab/InternViT-300M-448px"), os.path.join(models_path_prefix, "OpenGVLab/InternViT-300M-448px"),
allow_patterns=DOWNLOAD_PATTERN), os.path.join(models_path_prefix, "OpenGVLab/InternViT-6B-448px-V1-5"),
snapshot_download(os.path.join(models_path_prefix, "OpenGVLab/InternViT-6B-448px-V1-5"),
allow_patterns=DOWNLOAD_PATTERN),
] ]
......
...@@ -5,12 +5,13 @@ from typing import Dict, Tuple ...@@ -5,12 +5,13 @@ from typing import Dict, Tuple
import numpy as np import numpy as np
import pytest import pytest
import os
from PIL import Image from PIL import Image
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import (async_fetch_image, fetch_image, from vllm.multimodal.utils import (async_fetch_image, fetch_image,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from ..utils import urls_port from ..utils import models_path_prefix, urls_port
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [ TEST_IMAGE_URLS = [
...@@ -85,7 +86,7 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], ...@@ -85,7 +86,7 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
assert _image_equals(data_image_sync, data_image_async) assert _image_equals(data_image_sync, data_image_async)
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "llava-hf/llava-v1.6-mistral-7b-hf")])
def test_repeat_and_pad_placeholder_tokens(model): def test_repeat_and_pad_placeholder_tokens(model):
config = AutoConfig.from_pretrained(model) config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index image_token_id = config.image_token_index
......
...@@ -88,7 +88,7 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config): ...@@ -88,7 +88,7 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner): def test_can_deserialize_s3(vllm_runner):
model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b") model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" tensorized_path = f"{model_ref}/fp16/model.tensors"
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
...@@ -341,7 +341,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner): ...@@ -341,7 +341,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
def test_tensorizer_with_tp_path_without_template(vllm_runner): def test_tensorizer_with_tp_path_without_template(vllm_runner):
with pytest.raises(ValueError): with pytest.raises(ValueError):
model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b") model_ref = os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" tensorized_path = f"{model_ref}/fp16/model.tensors"
vllm_runner( vllm_runner(
model_ref, model_ref,
......
...@@ -5,9 +5,15 @@ from transformers import PreTrainedTokenizerBase ...@@ -5,9 +5,15 @@ from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from ..utils import models_path_prefix from ..utils import models_path_prefix
# TOKENIZER_NAMES = [
# os.path.join(models_path_prefix, "facebook/opt-125m"),
# os.path.join(models_path_prefix, "gpt2"),
# ]
# export HF_ENDPOINT=https://hf-mirror.com
TOKENIZER_NAMES = [ TOKENIZER_NAMES = [
os.path.join(models_path_prefix, "facebook/opt-125m"), "facebook/opt-125m",
os.path.join(models_path_prefix, "gpt2"), "gpt2",
] ]
......
...@@ -23,7 +23,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -23,7 +23,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless, from vllm.utils import (FlexibleArgumentParser, GB_bytes,cuda_device_count_stateless,
get_open_port, is_hip) get_open_port, is_hip)
import vllm.envs as envs import vllm.envs as envs
import os import os
...@@ -459,6 +459,36 @@ def fork_new_process_for_each_test( ...@@ -459,6 +459,36 @@ def fork_new_process_for_each_test(
return wrapper return wrapper
def large_gpu_test(*, min_gb: int):
    """
    Decorate a test to be skipped if no GPU is available or it does not have
    sufficient memory.

    Currently, the CI machine uses L4 GPU which has 24 GB VRAM.

    Args:
        min_gb: Minimum total device memory, in GB, required to run the test.

    Returns:
        A decorator that wraps the test with
        ``fork_new_process_for_each_test`` and marks it with a
        ``pytest.mark.skipif`` that triggers when the detected device memory
        is below ``min_gb``.
    """
    try:
        if current_platform.is_cpu():
            memory_gb = 0
        else:
            memory_gb = current_platform.get_device_total_memory() / GB_bytes
    except Exception as e:
        # Best effort: if the memory query fails, warn and treat the device
        # as having no memory so that the decorated test is skipped.
        warnings.warn(
            f"An error occurred when finding the available memory: {e}",
            stacklevel=2,
        )
        memory_gb = 0

    test_skipif = pytest.mark.skipif(
        memory_gb < min_gb,
        # Report the requirement (min_gb), not the amount that was detected.
        reason=f"Need at least {min_gb}GB GPU memory to run the test.",
    )

    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
        return test_skipif(fork_new_process_for_each_test(f))

    return wrapper
def multi_gpu_test(*, num_gpus: int): def multi_gpu_test(*, num_gpus: int):
""" """
......
...@@ -11,7 +11,7 @@ from vllm.outputs import (CompletionOutput, EmbeddingOutput, ...@@ -11,7 +11,7 @@ from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput) EmbeddingRequestOutput, RequestOutput)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.version import __version__, __version_tuple__, __dcu_version__ from vllm.version import __version__, __version_tuple__, __hcu_version__
__all__ = [ __all__ = [
......
...@@ -12,7 +12,7 @@ from vllm.platforms import current_platform ...@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
try: try:
from lmslim import quant_ops from lmslim import quant_ops
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq model.\n") print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -706,9 +706,9 @@ def cutlass_scaled_mm(a: torch.Tensor, ...@@ -706,9 +706,9 @@ def cutlass_scaled_mm(a: torch.Tensor,
# torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
# return out # return out
#return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias) return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def rocblas_scaled_mm(a: torch.Tensor, def rocblas_scaled_mm(a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
......
...@@ -207,7 +207,7 @@ def which_attn_to_use( ...@@ -207,7 +207,7 @@ def which_attn_to_use(
# not Instinct series GPUs. # not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.") logger.info("flash_attn is not supported on NAVI GPUs.")
else: else:
logger.info("%s is not supported in AMD GPUs.", selected_backend) logger.info("%s is not supported in hcus.", selected_backend)
return _Backend.ROCM_FLASH return _Backend.ROCM_FLASH
# FlashAttn in NVIDIA GPUs. # FlashAttn in NVIDIA GPUs.
......
...@@ -522,7 +522,7 @@ if __name__ == "__main__": ...@@ -522,7 +522,7 @@ if __name__ == "__main__":
default="auto", default="auto",
help='Data type for kv cache storage. If "auto", will use model ' help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') 'ROCm (hcu) supports fp8 (=fp8_e4m3)')
parser.add_argument( parser.add_argument(
'--quantization-param-path', '--quantization-param-path',
type=str, type=str,
...@@ -531,7 +531,7 @@ if __name__ == "__main__": ...@@ -531,7 +531,7 @@ if __name__ == "__main__":
'This should generally be supplied, when KV cache dtype is FP8. ' 'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause ' 'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' 'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' 'cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is '
'instead supported for common inference criteria.') 'instead supported for common inference criteria.')
parser.add_argument("--device", parser.add_argument("--device",
type=str, type=str,
......
...@@ -936,7 +936,7 @@ class ParallelConfig: ...@@ -936,7 +936,7 @@ class ParallelConfig:
self.disable_custom_all_reduce = True self.disable_custom_all_reduce = True
logger.info( logger.info(
"Disabled the custom all-reduce kernel because it is not " "Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.") "supported on hcus.")
if self.ray_workers_use_nsight and not self.use_ray: if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError("Unable to use nsight profiling unless workers " raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.") "run with Ray.")
......
...@@ -195,7 +195,7 @@ class NCCLLibrary: ...@@ -195,7 +195,7 @@ class NCCLLibrary:
except Exception as e: except Exception as e:
logger.error( logger.error(
"Failed to load NCCL library from %s ." "Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs." "It is expected if you are not running on NVIDIA/hcus."
"Otherwise, the nccl library might not exist, be corrupted " "Otherwise, the nccl library might not exist, be corrupted "
"or it does not support the current platform %s." "or it does not support the current platform %s."
"If you already have the library, please set the " "If you already have the library, please set the "
......
...@@ -294,7 +294,7 @@ class EngineArgs: ...@@ -294,7 +294,7 @@ class EngineArgs:
default=EngineArgs.kv_cache_dtype, default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model ' help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') 'ROCm (hcu) supports fp8 (=fp8_e4m3)')
parser.add_argument( parser.add_argument(
'--quantization-param-path', '--quantization-param-path',
type=nullable_str, type=nullable_str,
...@@ -304,7 +304,7 @@ class EngineArgs: ...@@ -304,7 +304,7 @@ class EngineArgs:
'KV cache dtype is FP8. Otherwise, KV cache scaling factors ' 'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
'default to 1.0, which may cause accuracy issues. ' 'default to 1.0, which may cause accuracy issues. '
'FP8_E5M2 (without scaling) is only supported on cuda version' 'FP8_E5M2 (without scaling) is only supported on cuda version'
'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' 'greater than 11.8. On ROCm (hcu), FP8_E4M3 is instead '
'supported for common inference criteria.') 'supported for common inference criteria.')
parser.add_argument('--max-model-len', parser.add_argument('--max-model-len',
type=int, type=int,
......
...@@ -4,12 +4,12 @@ import torch ...@@ -4,12 +4,12 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip from vllm.utils import is_hip,W8a8GetCacheJSON
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
W8A8_TRITONJSON=W8a8GetCacheJSON()
def cutlass_fp8_supported() -> bool: def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm # cutlass is not supported on Rocm
...@@ -200,12 +200,37 @@ def apply_int8_linear( ...@@ -200,12 +200,37 @@ def apply_int8_linear(
x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale) x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)
if w8a8_strategy==1: if w8a8_strategy==1:
m=x_q.shape[0]
k=x_q.shape[1]
n=weight.shape[1]
if f"{m}_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]:
best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
#print("json files:",best_config)
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]:
if m<64:
m_= 32
elif m<128:
m_=64
elif m<256:
m_=128
elif m<512:
m_=256
elif m<1024:
m_=512
else:
m_=1024
best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m_}_{n}_{k}"]
else:
best_config=None
print("config not found!")
return ops.triton_scaled_mm(x_q, return ops.triton_scaled_mm(x_q,
weight, weight,
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale, scale_b=weight_scale,
out_dtype=input.dtype, out_dtype=input.dtype,
bias=bias) bias=bias,best_config=best_config)
elif w8a8_strategy==2: elif w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q, return ops.cutlass_scaled_mm(x_q,
weight, weight,
......
...@@ -23,14 +23,14 @@ def get_model_architecture( ...@@ -23,14 +23,14 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
else: else:
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if architectures == ['BloomForCausalLM'] or os.getenv('LM_NN') == '0': if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0':
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
else: else:
os.environ['LM_NN'] = '1' os.environ['LM_NN'] = '1'
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
import math import math
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Tuple, Union
import os
import re
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
...@@ -47,6 +49,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -47,6 +49,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
...@@ -176,6 +181,11 @@ class FalconAttention(nn.Module): ...@@ -176,6 +181,11 @@ class FalconAttention(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -184,6 +194,8 @@ class FalconAttention(nn.Module): ...@@ -184,6 +194,8 @@ class FalconAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states) qkv, bias = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32]
if bias is not None: if bias is not None:
qkv += bias qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
...@@ -246,6 +258,9 @@ class FalconDecoderLayer(nn.Module): ...@@ -246,6 +258,9 @@ class FalconDecoderLayer(nn.Module):
self.mlp = FalconMLP(config, quant_config) self.mlp = FalconMLP(config, quant_config)
self.config = config self.config = config
if (not hasattr(config, "num_ln_in_parallel_attn")):
config.num_ln_in_parallel_attn = None
if (config.num_ln_in_parallel_attn is None if (config.num_ln_in_parallel_attn is None
and config.new_decoder_architecture): and config.new_decoder_architecture):
config.num_ln_in_parallel_attn = 2 config.num_ln_in_parallel_attn = 2
...@@ -404,6 +419,17 @@ class FalconForCausalLM(nn.Module): ...@@ -404,6 +419,17 @@ class FalconForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
def forward( def forward(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
...@@ -481,3 +507,33 @@ class FalconForCausalLM(nn.Module): ...@@ -481,3 +507,33 @@ class FalconForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and self.quant_method is None :
lay_key_words = [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight",
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attention.query_key_value.weight"]
qkv_words = "|".join(lay_qkv_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment