Unverified Commit 611a5a80 authored by Xu Kai's avatar Xu Kai Committed by GitHub
Browse files

[inference] Add smmoothquant for llama (#4904)

* [inference] add int8 rotary embedding kernel for smoothquant (#4843)

* [inference] add smoothquant llama attention (#4850)

* add smoothquant llama attention

* remove uselss code

* remove useless code

* fix import error

* rename file name

* [inference] add silu linear fusion for smoothquant llama mlp  (#4853)

* add silu linear

* update skip condition

* catch smoothquant cuda lib exception

* prcocess exception for tests

* [inference] add llama mlp for smoothquant (#4854)

* add llama mlp for smoothquant

* fix down out scale

* remove duplicate lines

* add llama mlp check

* delete useless code

* [inference] add smoothquant llama (#4861)

* add smoothquant llama

* fix attention accuracy

* fix accuracy

* add kv cache and save pretrained

* refactor example

* delete smooth

* refactor code

* [inference] add smooth function and delete useless code for smoothquant (#4895)

* add smooth function and delete useless code

* update datasets

* remove duplicate import

* delete useless file

* refactor codes (#4902)

* rafactor code

* add license

* add torch-int and smoothquant license
parent a0684e7b
...@@ -477,3 +477,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. ...@@ -477,3 +477,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
---------------- LICENSE FOR torch-int ----------------
MIT License
Copyright (c) 2022 Guangxuan Xiao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---------------- LICENSE FOR smoothquant ----------------
MIT License
Copyright (c) 2022 MIT HAN Lab
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
try:
import torch_int
HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
raise ImportError(
"Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
)
if HAS_TORCH_INT:
from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
import os
import warnings
from abc import abstractmethod
from functools import partial
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union
import accelerate
import numpy as np
import torch
import torch.nn as nn
import transformers
from safetensors.torch import save_file as safe_save
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import PushToHubMixin, cached_file
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
SUPPORTED_MODELS = ["llama"]
class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
layer_type: str = None
def __init__(self, model: PreTrainedModel, quantized: bool = False):
super().__init__()
self.model = model
self.model_type = self.model.config.model_type
self._quantized = quantized
self.config = self.model.config
self.cache_manager = None
self.max_total_token_num = 0
@property
def quantized(self):
return self._quantized
def init_cache_manager(self, max_total_token_num=2048):
if self.config.model_type == "llama":
head_num = self.config.num_key_value_heads
layer_num = self.config.num_hidden_layers
head_dim = self.config.hidden_size // head_num
self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
self.max_total_token_num = max_total_token_num
def init_batch_state(self, max_output_len=256, **kwargs):
input_ids = kwargs["input_ids"]
batch_size = len(input_ids)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
max_len_in_batch = -1
for i in range(batch_size):
seq_len = len(input_ids[i])
seq_lengths[i] = seq_len
seq_start_indexes[i] = start_index
start_index += seq_len
max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch
if "max_total_token_num" in kwargs.keys():
max_total_token_num = kwargs["max_total_token_num"]
self.init_cache_manager(max_total_token_num)
if "max_new_tokens" in kwargs.keys():
max_output_len = kwargs["max_new_tokens"]
if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num:
max_total_token_num = batch_size * (max_len_in_batch + max_output_len)
warnings.warn(f"reset max tokens to {max_total_token_num}")
self.init_cache_manager(max_total_token_num)
block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to("cuda")
batch_infer_state.start_loc = seq_start_indexes.to("cuda")
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
batch_infer_state.past_key_values_len = 0
batch_infer_state.is_context_stage = True
batch_infer_state.set_cache_manager(self.cache_manager)
batch_infer_state.cache_manager.free_all()
return batch_infer_state
@abstractmethod
@torch.inference_mode()
def quantize(
self,
examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
):
if self.quantized:
raise EnvironmentError("can't execute quantize because the model is quantized.")
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
def generate(self, **kwargs):
"""shortcut for model.generate"""
batch_infer_state = self.init_batch_state(**kwargs)
if self.config.model_type == "llama":
setattr(self.model.model, "infer_state", batch_infer_state)
with torch.inference_mode():
return self.model.generate(**kwargs)
def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
return self.model.prepare_inputs_for_generation(*args, **kwargs)
def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
for text in tqdm(dataset):
input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
model(input_ids)
def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
pbar = tqdm(dataset)
for text in pbar:
input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
model(input_ids)
mean_scale = np.mean([v["input"] for v in act_dict.values()])
pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
model.eval()
device = next(model.parameters()).device
act_scales = {}
def stat_tensor(name, tensor):
hidden_dim = tensor.shape[-1]
tensor = tensor.view(-1, hidden_dim).abs().detach()
comming_max = torch.max(tensor, dim=0)[0].float().cpu()
if name in act_scales:
act_scales[name] = torch.max(act_scales[name], comming_max)
else:
act_scales[name] = comming_max
def stat_input_hook(m, x, y, name):
if isinstance(x, tuple):
x = x[0]
stat_tensor(name, x)
hooks = []
for name, m in model.named_modules():
if isinstance(m, nn.Linear):
hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name)))
self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)
for h in hooks:
h.remove()
return act_scales
@torch.no_grad()
def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
if not isinstance(fcs, list):
fcs = [fcs]
for fc in fcs:
assert isinstance(fc, nn.Linear)
assert ln.weight.numel() == fc.in_features == act_scales.numel()
device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
act_scales = act_scales.to(device=device, dtype=dtype)
weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
ln.weight.div_(scales)
if hasattr(ln, "bias"):
ln.bias.div_(scales)
for fc in fcs:
fc.weight.mul_(scales.view(1, -1))
@classmethod
def create_quantized_model(model):
raise NotImplementedError("Not implement create_quantized_model method")
def save_quantized(
self,
save_dir: str,
model_basename: str,
use_safetensors: bool = False,
safetensors_metadata: Optional[Dict[str, str]] = None,
):
"""save quantized model and configs to local disk"""
os.makedirs(save_dir, exist_ok=True)
if not self.quantized:
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
self.model.to("cpu")
model_base_name = model_basename # or f"smooth-"
if use_safetensors:
model_save_name = model_base_name + ".safetensors"
state_dict = self.model.state_dict()
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
if safetensors_metadata is None:
safetensors_metadata = {}
elif not isinstance(safetensors_metadata, dict):
raise TypeError("safetensors_metadata must be a dictionary.")
else:
print(f"Received safetensors_metadata: {safetensors_metadata}")
new_safetensors_metadata = {}
converted_keys = False
for key, value in safetensors_metadata.items():
if not isinstance(key, str) or not isinstance(value, str):
converted_keys = True
try:
new_key = str(key)
new_value = str(value)
except Exception as e:
raise TypeError(
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
)
if new_key in new_safetensors_metadata:
print(
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
)
new_safetensors_metadata[new_key] = new_value
safetensors_metadata = new_safetensors_metadata
if converted_keys:
print(
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
)
# Format is required to enable Accelerate to load the metadata
# otherwise it raises an OSError
safetensors_metadata["format"] = "pt"
safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
else:
model_save_name = model_base_name + ".bin"
torch.save(self.model.state_dict(), join(save_dir, model_save_name))
self.model.config.save_pretrained(save_dir)
def save_pretrained(
self,
save_dir: str,
use_safetensors: bool = False,
safetensors_metadata: Optional[Dict[str, str]] = None,
**kwargs,
):
"""alias of save_quantized"""
warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
max_memory: Optional[dict] = None,
trust_remote_code: bool = False,
torch_dtype: torch.dtype = torch.float16,
**model_init_kwargs,
):
if not torch.cuda.is_available():
raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
# Parameters related to loading from Hugging Face Hub
cache_dir = model_init_kwargs.pop("cache_dir", None)
force_download = model_init_kwargs.pop("force_download", False)
resume_download = model_init_kwargs.pop("resume_download", False)
proxies = model_init_kwargs.pop("proxies", None)
local_files_only = model_init_kwargs.pop("local_files_only", False)
use_auth_token = model_init_kwargs.pop("use_auth_token", None)
revision = model_init_kwargs.pop("revision", None)
subfolder = model_init_kwargs.pop("subfolder", "")
model_init_kwargs.pop("_commit_hash", None)
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"revision": revision,
"subfolder": subfolder,
}
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
# enforce some values despite user specified
model_init_kwargs["torch_dtype"] = torch_dtype
model_init_kwargs["trust_remote_code"] = trust_remote_code
if max_memory:
if "disk" in max_memory:
raise NotImplementedError("disk offload not support yet.")
with accelerate.init_empty_weights():
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
model.tie_weights()
max_memory = accelerate.utils.get_balanced_memory(
model,
max_memory=max_memory,
no_split_module_classes=[cls.layer_type],
dtype=model_init_kwargs["torch_dtype"],
low_zero=False,
)
model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
model,
max_memory=max_memory,
no_split_module_classes=[cls.layer_type],
dtype=model_init_kwargs["torch_dtype"],
)
model_init_kwargs["low_cpu_mem_usage"] = True
del model
else:
model_init_kwargs["device_map"] = None
model_init_kwargs["low_cpu_mem_usage"] = False
torch.cuda.empty_cache()
merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
if any([k in model_config for k in seq_len_keys]):
for key in seq_len_keys:
if key in model_config:
model.seqlen = model_config[key]
break
else:
warnings.warn("can't get model's sequence length from model config, will set to 4096.")
model.seqlen = 4096
model.eval()
return cls(model, False)
@classmethod
def from_quantized(
cls,
model_name_or_path: Optional[str],
model_basename: Optional[str] = None,
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
low_cpu_mem_usage: bool = False,
torch_dtype: Optional[torch.dtype] = None,
use_safetensors: bool = False,
trust_remote_code: bool = False,
**kwargs,
):
"""load quantized model from local disk"""
# Parameters related to loading from Hugging Face Hub
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
# == step1: prepare configs and file names == #
config = AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs
)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
extensions = []
if use_safetensors:
extensions.append(".safetensors")
else:
extensions += [".bin", ".pt"]
model_name_or_path = str(model_name_or_path)
is_local = isdir(model_name_or_path)
resolved_archive_file = None
if is_local:
model_save_name = join(model_name_or_path, model_basename)
for ext in extensions:
if isfile(model_save_name + ext):
resolved_archive_file = model_save_name + ext
break
else: # remote
for ext in extensions:
resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
if resolved_archive_file is not None:
break
if resolved_archive_file is None: # Could not find a model file to use
raise FileNotFoundError(f"Could not find model in {model_name_or_path}")
model_save_name = resolved_archive_file
# == step2: convert model to quantized-model (replace Linear) == #
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
transformers.modeling_utils._init_weights = False
init_contexts = [no_init_weights()]
if low_cpu_mem_usage:
init_contexts.append(accelerate.init_empty_weights(include_buffers=True))
with ContextManagers(init_contexts):
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
)
cls.create_quantized_model(model)
model.tie_weights()
# == step3: load checkpoint to quantized-model == #
accelerate.utils.modeling.load_checkpoint_in_model(
model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True
)
# == step4: set seqlen == #
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
if any([k in model_config for k in seq_len_keys]):
for key in seq_len_keys:
if key in model_config:
model.seqlen = model_config[key]
break
else:
warnings.warn("can't get model's sequence length from model config, will set to 4096.")
model.seqlen = 4096
return cls(
model,
True,
)
def __getattr__(self, item):
try:
return super().__getattr__(item)
except:
return getattr(self.model, item)
__all__ = ["BaseSmoothForCausalLM"]
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
import torch
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax
try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except ImportError:
HAS_SMOOTHQUANT_CUDA = False
raise ImportError("CUDA smoothquant linear is not installed")
class W8A8BFP32O32LinearSiLU(torch.nn.Module):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer(
"weight",
torch.randint(
-127,
127,
(self.out_features, self.in_features),
dtype=torch.int8,
requires_grad=False,
),
)
self.register_buffer(
"bias",
torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False),
)
self.register_buffer("a", torch.tensor(alpha))
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
return self
@torch.no_grad()
def forward(self, x):
x_shape = x.shape
x = x.view(-1, x_shape[-1])
y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
y = y.view(*x_shape[:-1], -1)
return y
@staticmethod
def from_float(module: torch.nn.Linear, input_scale):
int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
alpha = input_scale * weight_scale
int8_module.weight = int8_weight
if module.bias is not None:
int8_module.bias.data.copy_(module.bias.to(torch.float))
int8_module.a = alpha
return int8_module
class W8A8B8O8Linear(torch.nn.Module):
# For qkv_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer(
"weight",
torch.randint(
-127,
127,
(self.out_features, self.in_features),
dtype=torch.int8,
requires_grad=False,
),
)
self.register_buffer(
"bias",
torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False),
)
self.register_buffer("a", torch.tensor(alpha))
self.register_buffer("b", torch.tensor(beta))
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
return self
@torch.no_grad()
def forward(self, x):
x_shape = x.shape
x = x.view(-1, x_shape[-1])
y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item())
y = y.view(*x_shape[:-1], -1)
return y
@staticmethod
def from_float(module: torch.nn.Linear, input_scale, output_scale):
int8_module = W8A8B8O8Linear(module.in_features, module.out_features)
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
alpha = input_scale * weight_scale / output_scale
int8_module.weight = int8_weight
int8_module.a = alpha
if module.bias is not None:
int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
int8_module.bias = int8_bias
beta = bias_scale / output_scale
int8_module.b = beta
return int8_module
class W8A8BFP32OFP32Linear(torch.nn.Module):
# For fc2 and out_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer(
"weight",
torch.randint(
-127,
127,
(self.out_features, self.in_features),
dtype=torch.int8,
requires_grad=False,
),
)
self.register_buffer(
"bias",
torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
)
self.register_buffer("a", torch.tensor(alpha))
def _apply(self, fn):
# prevent the bias from being converted to half
super()._apply(fn)
self.bias = self.bias.to(torch.float32)
return self
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
self.bias = self.bias.to(torch.float32)
return self
@torch.no_grad()
def forward(self, x):
x_shape = x.shape
x = x.view(-1, x_shape[-1])
y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1)
y = y.view(*x_shape[:-1], -1)
return y
@staticmethod
def from_float(module: torch.nn.Linear, input_scale):
int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features)
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
alpha = input_scale * weight_scale
int8_module.weight = int8_weight
int8_module.a = alpha
int8_module.input_scale = input_scale
int8_module.weight_scale = weight_scale
if module.bias is not None:
int8_module.bias = module.bias.to(torch.float32)
return int8_module
This diff is collapsed.
#include <torch/extension.h>
#include "linear.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
"Linear SiLU (INT8)");
}
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu
#include "linear.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
torch::Tensor weight, // INT8
torch::Tensor bias, // FP32
float alpha, // FP32
float beta // FP32
) {
auto M = input.size(0);
auto N = weight.size(0);
auto K = input.size(1);
using ElementOutput = float;
using ElementAccumulator = int32_t;
using ElementComputeEpilogue = float;
using ElementInputA = int8_t; // <- data type of elements in input matrix A
using ElementInputB = int8_t; // <- data type of elements in input matrix B
// The code section below describes matrix layout of input and output
// matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major
// for Matrix C
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
#if CUDA_ARCH >= 800
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<
ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This
// becomes the vector width of math
// instructions in epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue // <- data type for alpha in linear combination
// function
>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
#elif CUDA_ARCH >= 750
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<
ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This
// becomes the vector width of math
// instructions in epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue // <- data type for alpha in linear combination
// function
>;
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
DefaultGemmCfg::InstructionShape,
EpilogueOp>;
#elif CUDA_ARCH >= 700
#define USE_TORCH_SILU
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
DefaultGemmCfg::InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
#else
#error "Unsupported cuda arch"
#endif
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
auto output_size = cutlass::MatrixCoord(M, N);
auto device = input.device();
// use the broadcasted bias as the output
auto out = bias.to(device).view({1, -1}).repeat({M, 1});
// constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
input_ref, // <- reference to matrix A on device
weight_ref, // <- reference to matrix B on device
out_ref, // <- reference to matrix C on device
out_ref, // <- reference to matrix D on device
{alpha, beta}, 1};
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
#ifdef USE_TORCH_SILU
#undef USE_TORCH_SILU
out = torch::silu(out);
#endif
return out;
}
#include <torch/torch.h>
#include <torch/types.h>
#include <cstdint>
#include <iostream>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
torch::Tensor weight, // INT8
torch::Tensor bias, // FP32
float alpha, // FP32
float beta // FP32
);
...@@ -13,8 +13,10 @@ if HAS_TRITON: ...@@ -13,8 +13,10 @@ if HAS_TRITON:
from .copy_kv_cache_dest import copy_kv_cache_to_dest from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm from .fused_layernorm import layer_norm
from .gptq_triton import gptq_fused_linear_triton from .gptq_triton import gptq_fused_linear_triton
from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
from .rms_norm import rmsnorm_forward from .rms_norm import rmsnorm_forward
from .rotary_embedding_kernel import rotary_embedding_fwd from .rotary_embedding_kernel import rotary_embedding_fwd
from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd
from .softmax import softmax from .softmax import softmax
from .token_attention_kernel import token_attention_fwd from .token_attention_kernel import token_attention_fwd
...@@ -29,4 +31,7 @@ if HAS_TRITON: ...@@ -29,4 +31,7 @@ if HAS_TRITON:
"rotary_embedding_fwd", "rotary_embedding_fwd",
"token_attention_fwd", "token_attention_fwd",
"gptq_fused_linear_triton", "gptq_fused_linear_triton",
"int8_rotary_embedding_fwd",
"smooth_llama_context_attn_fwd",
"smooth_token_attention_fwd",
] ]
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl
@triton.jit
def _rotary_kernel(
q,
input_scale,
output_scale,
Cos,
Sin,
q_bs_stride,
q_h_stride,
q_d_stride,
cos_bs_stride,
cos_d_stride,
total_len,
HEAD_NUM: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
BLOCK_SEQ: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
current_head_index = tl.program_id(0)
current_seq_index = tl.program_id(1)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
off_q0 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range0[None, None, :] * q_d_stride
)
off_q1 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range1[None, None, :] * q_d_stride
)
off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
q0 = tl.load(
q + off_q0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
q1 = tl.load(
q + off_q1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
q0 = q0.to(tl.float32) * input_scale
q1 = q1.to(tl.float32) * input_scale
out0 = (q0 * cos - q1 * sin) / output_scale
out1 = (q0 * sin + q1 * cos) / output_scale
out0 = out0.to(tl.int8)
out1 = out1.to(tl.int8)
tl.store(
q + off_q0,
out0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
tl.store(
q + off_q1,
out1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
return
@torch.no_grad()
def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):
total_len = q.shape[0]
head_num = q.shape[1]
head_dim = q.shape[2]
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
BLOCK_HEAD = 4
BLOCK_SEQ = 32
grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
if head_dim >= 128:
num_warps = 8
else:
num_warps = 4
_rotary_kernel[grid](
q,
input_scale,
output_scale,
cos,
sin,
q.stride(0),
q.stride(1),
q.stride(2),
cos.stride(0),
cos.stride(1),
total_len,
HEAD_NUM=head_num,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_SEQ=BLOCK_SEQ,
HEAD_DIM=head_dim,
num_warps=num_warps,
num_stages=1,
)
return
This diff is collapsed.
import argparse
import os
import torch
from datasets import load_dataset
from transformers import LlamaTokenizer
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
def build_model_and_tokenizer(model_name):
tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512)
kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"}
model = SmoothLlamaForCausalLM.from_pretrained(model_name, **kwargs)
model = model.to(torch.float32)
return model, tokenizer
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, help="model name")
parser.add_argument(
"--output-path",
type=str,
help="where to save the checkpoint",
)
parser.add_argument(
"--dataset-path",
type=str,
help="location of the calibration dataset",
)
parser.add_argument("--num-samples", type=int, default=512)
parser.add_argument("--seq-len", type=int, default=512)
args = parser.parse_args()
return args
@torch.no_grad()
def main():
args = parse_args()
model_path = args.model_name
dataset_path = args.dataset_path
output_path = args.output_path
num_samples = 10
seq_len = 512
model, tokenizer = build_model_and_tokenizer(model_path)
if not os.path.exists(dataset_path):
print(f"Cannot find the dataset at {args.dataset_path}")
raise FileNotFoundError
dataset = load_dataset("json", data_files=dataset_path, split="train")
model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len)
model = model.cuda()
model.save_quantized(output_path, model_basename="llama-7b")
model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
model = model.cuda()
generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True)
input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
out = model.generate(**input_tokens, **generate_kwargs)
text = tokenizer.batch_decode(out)
print("out is:", text)
if __name__ == "__main__":
main()
import torch
from .builder import Builder
from .utils import append_nvcc_threads
class SmoothquantBuilder(Builder):
NAME = "cu_smoothquant"
PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant"
def __init__(self):
super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH)
def include_dirs(self):
ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()]
return ret
def sources_files(self):
ret = [
self.csrc_abs_path(fname)
for fname in [
"smoothquant/binding.cpp",
"smoothquant/linear.cu",
]
]
return ret
def cxx_flags(self):
return ["-O3"] + self.version_dependent_macros
def nvcc_flags(self):
compute_capability = torch.cuda.get_device_capability()
cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10
extra_cuda_flags = [
"-v",
f"-DCUDA_ARCH={cuda_arch}",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
]
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)
def builder(self):
try:
super().builder()
except:
warnings.warn("build smoothquant lib not successful")
import pytest
import torch
from packaging import version
try:
from colossalai.kernel.triton import int8_rotary_embedding_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
try:
from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention
HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
import math
import torch
from torch.nn import functional as F
def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
"""
adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
"""
xq = xq.view(bs, seqlen, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
xv = xv.view(bs, seqlen, num_head, head_dim)
mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
mask[mask == 0.0] = -100000000.0
mask = mask.repeat(bs, num_head, 1, 1)
keys = xk
values = xv
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq)
output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
return output
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT,
reason="triton requires cuda version to be higher than 11.4 or not install torch_int",
)
def test_llama_context_attention():
head_num = 2
seq_len = 32
head_dim = 64
dtype = torch.float
hidden_size = head_num * head_dim
smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num)
smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8)
qkv_weight_scale = 1.0
ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda")
smooth_attn = smooth_attn.to("cuda")
input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda")
input_scale = 1 / 20.0
output = torch.matmul(input.to(torch.float) * input_scale, ones)
qkv_max_out = torch.max(torch.abs(output)) / 127
smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
q = smooth_attn.q_proj(input)
k = smooth_attn.k_proj(input)
v = smooth_attn.v_proj(input)
cos_shape = (seq_len, head_dim // 2)
cos = torch.ones(cos_shape, dtype=dtype, device="cuda")
sin = torch.zeros(cos_shape, dtype=dtype, device="cuda")
in_scale = torch.tensor([qkv_max_out], device="cuda")
out_scale = torch.tensor([qkv_max_out], device="cuda")
int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())
int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())
q = q.to(torch.float) * out_scale
k = k.to(torch.float) * out_scale
v = v.to(torch.float) * out_scale
torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim)
attn_out_max = torch.max(torch.abs(torch_out)) / 127
output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones)
smooth_attn.q_output_scale = torch.tensor(qkv_max_out)
smooth_attn.k_output_scale = torch.tensor(qkv_max_out)
smooth_attn.v_output_scale = torch.tensor(qkv_max_out)
smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out)
smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out)
smooth_attn.attn_output_scale = torch.tensor(attn_out_max)
smooth_attn.out_proj.a = torch.tensor([attn_out_max])
torch_out = (
(torch_out / smooth_attn.attn_output_scale)
.round()
.clamp(-128, 127)
.to(torch.int8)
.view(-1, seq_len, head_num * head_dim)
)
torch_out = smooth_attn.out_proj(torch_out)
torch_out = torch_out.to(torch.float)
smooth_attn = smooth_attn.to("cuda")
smooth_out, _, _ = smooth_attn(input, (cos, sin))
smooth_out = smooth_out.to(torch.float)
assert torch.allclose(
torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1
), "outputs from triton and torch are not matched"
if __name__ == "__main__":
test_llama_context_attention()
import warnings
import pytest
import torch
from packaging import version
try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except:
warnings.warn("CUDA smoothquant linear is not installed")
HAS_SMOOTHQUANT_CUDA = False
try:
from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP
HAS_TORCH_INT = True
except:
HAS_TORCH_INT = False
warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
def torch_llama_mlp(gate_proj, up_proj, down_proj, x):
gate_out = torch.mm(x, gate_proj)
silu = torch.nn.SiLU()
gate_out = silu(gate_out)
up_out = torch.mm(x, up_proj)
o_out = gate_out * up_out
max_up = torch.max(torch.abs(o_out))
min_up = torch.min(torch.abs(o_out))
torch_out = torch.mm(o_out, down_proj)
return (torch_out, max_up, min_up)
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT,
reason="smoothquant linear not installed properly or not install torch_int",
)
def test_llama_mlp():
hidden_size = 256
intermediate_size = 512
smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size)
smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda")
smooth_mlp.up_proj.weight = torch.randint(
-10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda"
)
smooth_mlp.down_proj.weight = torch.randint(
-10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda"
)
x = torch.ones((1, 256), dtype=torch.int8, device="cuda")
torch_out, max_inter, min_inter = torch_llama_mlp(
smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size,
smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127,
smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127,
x.to(torch.float),
)
smooth_mlp.down_proj_input_scale = torch.tensor(max_inter.item() / 127)
smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size)
smooth_mlp.up_proj.a = torch.tensor(1 / 127)
smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127))
smooth_out = smooth_mlp(x)
assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01)
if __name__ == "__main__":
test_llama_mlp()
import warnings
import pytest
import torch
try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except:
warnings.warn("CUDA smoothquant linear is not installed")
HAS_SMOOTHQUANT_CUDA = False
@pytest.mark.skipif(
not HAS_SMOOTHQUANT_CUDA,
reason="smoothquant linear not installed properly",
)
def test_linear():
a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda")
b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda")
c = torch.rand(256, dtype=torch.float, device="cuda")
alpha = 1 / 127
beta = 1.0
torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c
silu = torch.nn.SiLU()
torch_out = silu(torch_out)
b = b.transpose(0, 1).contiguous()
cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta)
assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02)
if __name__ == "__main__":
test_linear()
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import pytest
import torch
from packaging import version
try:
from colossalai.kernel.triton import int8_rotary_embedding_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
def torch_rotary_emb(x, cos, sin):
seq_len, h, dim = x.shape
x0 = x[:, :, 0 : dim // 2]
x1 = x[:, :, dim // 2 : dim]
cos = cos.view((seq_len, 1, dim // 2))
sin = sin.view((seq_len, 1, dim // 2))
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
return torch.cat((o0, o1), dim=-1)
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_rotary_emb():
SEQ_LEN = 1
HEAD_NUM = 32
HEAD_DIM = 128
dtype = torch.float
# create data
x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
cos_shape = (SEQ_LEN, HEAD_DIM // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
# forward pass
y_torch = torch_rotary_emb(x, cos, sin)
input_scale = torch.max(torch.abs(x)) / 127
output_scale = torch.max(torch.abs(y_torch)) / 127
x = x / input_scale
x = x.to(torch.int8)
int8_rotary_embedding_fwd(x, cos, sin, input_scale.item(), output_scale.item())
y_triton = x.to(torch.float) * output_scale
assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True)
if __name__ == "__main__":
test_rotary_emb()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment