Unverified Commit 3fff964d authored by pppppM, committed by GitHub

[Feature] Stats Quantization Parameters for KV Cache (#45)

* add cal qparams

* support offload inference

* add collect functions (mod, weight)

* stats kv scales

* update init

* add user guide

* fix hints

* fix comments & support turbomind format

* update user guide

* fix slice kv cache error & support pileval dataset (used in llm-awq)

* fix wrong num heads slice

* update default dataset

* fix conflict

* fix hints

* fix hints

* add gitignore
parent edb6eb86
......@@ -8,4 +8,5 @@ workspace/
lmdeploy/lib/
dist/
examples/cpp/llama/*.csv
*.npy
*.weight
......@@ -153,6 +153,20 @@ python3 lmdeploy/app.py {server_ip_addresss}:33337 {model_name}
In fp16 mode, kv_cache int8 quantization can be enabled so that a single GPU can serve more users.
First run the quantization script; the quantization parameters are stored in the weight directory produced by `deploy.py`.
```
python3 -m lmdeploy.lite.apis.kv_qparams \
  --model $HF_MODEL \
  --output_dir $DEPLOY_WEIGHT_DIR \
  --symmetry True \
  --offload False \
  --num_tp 1
```
`--symmetry` selects symmetric or asymmetric quantization, `--offload` offloads some modules to the CPU to save GPU memory, and `--num_tp` is the number of GPUs used for tensor parallelism.
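The script writes one float32 file per layer and tensor-parallel rank into the output directory, named like `layers.0.past_kv_scale.0.weight`. A quick way to sanity-check the generated files (a minimal sketch; the path below is a placeholder) is to read them back with NumPy:
```
import numpy as np

# Placeholder path: substitute your own $DEPLOY_WEIGHT_DIR and indices.
scale_file = 'path/to/deploy_weight_dir/layers.0.past_kv_scale.0.weight'

# Symmetric mode stores [k_scale, v_scale]; asymmetric mode stores
# [k_scale, k_zero, v_scale, v_zero], all as float32.
print(np.fromfile(scale_file, dtype=np.float32))
```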
Then adjust `config.ini`
- set `use_context_fmha` to 0, which turns it off
......
......@@ -50,7 +50,6 @@ LMDeploy 由 [MMDeploy](https://github.com/open-mmlab/mmdeploy) 和 [MMRazor](ht
![PersistentBatchInference](https://github.com/open-mmlab/lmdeploy/assets/25839884/8f8b57b8-42af-4b71-ad74-e75f39b10694)
## Quick Start
### Installation
......@@ -151,6 +150,16 @@ python3 lmdeploy/app.py {server_ip_addresss}:33337 {model_name}
In fp16 mode, kv_cache int8 quantization can be enabled so that a single GPU can serve more users.
First run the quantization script; the quantization parameters are stored in the weight directory produced by `deploy.py`.
```
python3 -m lmdeploy.lite.apis.kv_qparams \
  --model $HF_MODEL \
  --output_dir $DEPLOY_WEIGHT_DIR \
  --symmetry True \
  --offload False \
  --num_tp 1
```
`--symmetry` selects symmetric or asymmetric quantization (default True), `--offload` keeps the model on the CPU and loads only some modules onto the GPU during inference (default False), and `--num_tp` is the number of GPUs used for tensor parallelism; keep it consistent with `deploy.py`.
Then adjust `config.ini`
- set `use_context_fmha` to 0, which turns it off
......
# import multiprocessing as mp
from threading import Thread
from queue import Queue
import time
from queue import Queue
from threading import Thread
import fire
import numpy as np
from transformers import AutoTokenizer
from lmdeploy.turbomind import TurboMind
from lmdeploy.model import MODELS
from transformers import AutoTokenizer
from lmdeploy.turbomind import TurboMind
def infer(model, session_id: int, input_ids: str, output_seqlen: int,
......@@ -19,13 +19,12 @@ def infer(model, session_id: int, input_ids: str, output_seqlen: int,
start = time.perf_counter()
timestamps = [start]
tokens = [0]
for outputs in chatbot.stream_infer(
session_id,
input_ids,
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True,
ignore_eos=True):
for outputs in chatbot.stream_infer(session_id,
input_ids,
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True,
ignore_eos=True):
res, token = outputs[0]
timestamps.append(time.perf_counter())
tokens.append(token)
......@@ -48,13 +47,12 @@ def warmup(model,
def _infer(model, session_id):
chatbot = model.create_instance()
for _ in range(warmup_round):
for _ in chatbot.stream_infer(
session_id,
input_ids=[1],
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True,
ignore_eos=True):
for _ in chatbot.stream_infer(session_id,
input_ids=[1],
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True,
ignore_eos=True):
continue
_start = time.perf_counter()
......@@ -88,8 +86,7 @@ def main(model_path: str,
stop_words = model.stop_words
tm_model = TurboMind(model_path=model_path, stop_words=stop_words)
warmup(tm_model, concurrency, session_len,
output_seqlen)
warmup(tm_model, concurrency, session_len, output_seqlen)
# make up a prompt that can be tokenized into {input_seqlen} tokens
prompt = '' if input_seqlen == 0 else 'hi' + ' hi' * (input_seqlen - 1)
......@@ -101,8 +98,8 @@ def main(model_path: str,
# TODO: update to the multithread version
for i in range(concurrency):
proc = Thread(target=infer,
args=(tm_model, i + 1, input_ids, output_seqlen,
test_round, que))
args=(tm_model, i + 1, input_ids, output_seqlen,
test_round, que))
procs.append(proc)
proc.start()
......
# PTQ Quantization Test Results
## GPU Memory Test
The test target is the [Chinese-LLaMa-Alpaca 7B](https://github.com/ymcui/Chinese-LLaMA-Alpaca) instruction model.
Test procedure:
1. Convert the model with `deploy.py`, change the maximum concurrency in the `workspace` configuration, and adjust the number of requests in `llama_config.ini`
2. Build and run `bin/llama_triton_example` to record the GPU memory usage of the fp16 version at different batch sizes
3. Run the quantization script to obtain the quantization parameters, then modify the configuration file so that the [kCacheKVInt8](../../src/turbomind/models/llama/llama_utils.h) option takes effect
......@@ -11,11 +13,11 @@
The GPU memory comparison between the two versions is as follows:
| batch_size | fp16 memory(MiB) | int8 memory(MiB) | diff(MiB) |
| :--------: | :--------------: | :--------------: | :-------: |
| 8          | 22337            | 18241            | -4096     |
| 16         | 30593            | 22369            | -8224     |
| 32         | 47073            | 30625            | -16448    |
| 48         | 63553            | 38881            | -24672    |
Compared with quantizing the weights directly (e.g. [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/)), we estimated the memory growth of both approaches on a 7B model; part of the data comes from [llama.cpp](https://github.com/ggerganov/llama.cpp).
......@@ -29,6 +31,7 @@
The test target is the [Chinese-LLaMa-Alpaca 7B](https://github.com/ymcui/Chinese-LLaMA-Alpaca) instruction model.
Test procedure:
1. Convert the model with `deploy.py` and start the docker service
2. Evaluate the dataset through `client.py` to obtain the accuracy of the fp16 version
3. Run the quantization script to obtain the quantization parameters, put them in the weights directory, and modify the configuration file so that the [kCacheKVInt8](../../src/turbomind/models/llama/llama_utils.h) option takes effect
......@@ -36,16 +39,16 @@
Below is the accuracy loss of the `kCacheKVInt8` method on the mmlu-social-science dataset, calibrated with the c4 dataset only.
| task | dataset             | metric | fp16  | int8  | diff  |
| :--: | :-----------------: | :----: | :---: | :---: | :---: |
| Exam | mmlu-social-science | score  | 31.81 | 32.00 | +0.19 |
We noticed a slight accuracy improvement. mmlu-social-science contains 3065 multiple-choice questions in total; the detailed differences are as follows:
| Case                                               | Count |
| :------------------------------------------------: | :---: |
| Answered incorrectly by fp16 but correctly by int8 | 72    |
| Answered correctly by fp16 but incorrectly by int8 | 66    |
| Both wrong, with different answers                 | 118   |
We have already validated more datasets on larger models and will keep updating the results.
import random
import fire
from transformers import AutoTokenizer
from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from transformers import AutoTokenizer
import random
def input_prompt():
......@@ -46,8 +48,8 @@ def main(model_name, model_path, tokenizer_model_path, session_id: int = 1):
random_seed=seed if nth_round == 1 else None):
res, tokens = outputs[0]
# decode res
response = tokenizer.decode(
res[step:], skip_special_tokens=True)
response = tokenizer.decode(res[step:],
skip_special_tokens=True)
print(f'session {session_id}, {tokens}, {response}')
# update step
step = tokens - 1
......
# Copyright (c) OpenMMLab. All rights reserved.
from .apis import * # noqa: F401,F403
from .quantization import * # noqa: F401,F403
from .utils import * # noqa: F401,F403
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import List, Tuple
import fire
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from transformers.models.llama.modeling_llama import (LlamaDecoderLayer,
LlamaForCausalLM)
from lmdeploy.lite.quantization import Observer
from lmdeploy.lite.utils import get_calib_loaders, memory_efficient_inference
# OFFLOAD_MOD_MAP is a dictionary that specifies which parts of
# certain model types should be offloaded to the CPU during inference.
# The key of this dictionary is a model class and the value is a tuple
# of modules within that model that should be offloaded.
# As an example, here it is specified that for the LlamaForCausalLM model,
# only the LlamaDecoderLayer should be offloaded. This might be because
# the LlamaDecoderLayer consumes a significant amount of GPU memory
# and offloading it when not in use can help save GPU resources.
OFFLOAD_MOD_MAP = {LlamaForCausalLM: (LlamaDecoderLayer, )}
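# Hypothetical example of registering another architecture (the class
# names below are placeholders, not real imports):
# OFFLOAD_MOD_MAP[SomeOtherForCausalLM] = (SomeOtherDecoderLayer, )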
def absmax(tensor: torch.Tensor) -> float:
"""Returns the maximum absolute value in a tensor.
Args:
tensor (torch.Tensor): Input tensor.
Returns:
float: Maximum absolute value in the tensor.
"""
return tensor.abs().max().item()
def minmax(tensor: torch.Tensor) -> Tuple[float, float]:
"""Returns the minimum and maximum value in a tensor.
Args:
tensor (torch.Tensor): Input tensor.
Returns:
tuple: Minimum and maximum value in the tensor.
"""
return (tensor.min().item(), tensor.max().item())
def stats_past_key_values(past_key_values: List[torch.Tensor],
k_obs_list: List[Observer],
v_obs_list: List[Observer], symmetry: bool,
num_tp: int) -> None:
"""Collects statistics for past key values.
Args:
past_key_values (List[Tensor]): Past key values generated by the
model during forward pass.
k_obs_list (List[Observer]): List of observers for collecting
stats for keys.
v_obs_list (List[Observer]): List of observers for collecting
stats for values.
symmetry (bool): Whether to use symmetric or asymmetric quantization.
num_tp (int): The number of GPUs used for tensor parallelism.
"""
if len(k_obs_list) == 0 and len(v_obs_list) == 0:
num_layers = len(past_key_values)
for _ in range(num_layers * num_tp):
if symmetry:
k_observer = Observer(absmax)
v_observer = Observer(absmax)
else:
k_observer = Observer(minmax)
v_observer = Observer(minmax)
k_observer.enable_observer()
v_observer.enable_observer()
k_obs_list.append(k_observer)
v_obs_list.append(v_observer)
assert len(k_obs_list) == len(past_key_values) * num_tp
assert len(v_obs_list) == len(past_key_values) * num_tp
for layer, (k_cache, v_cache) in enumerate(past_key_values):
for tp in range(num_tp):
k_obs = k_obs_list[layer * num_tp + tp]
v_obs = v_obs_list[layer * num_tp + tp]
# K Cache Shape: [Bs, Heads, Tokens, Dims]
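# With num_tp > 1, each tensor-parallel rank only sees a contiguous
# slice of the attention heads, so statistics are collected per slice:
# rank `tp` covers heads [tp * per_tp_heads, (tp + 1) * per_tp_heads).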
per_tp_heads = k_cache.size(1) // num_tp
k_obs(k_cache[:, tp * per_tp_heads:(tp + 1) * per_tp_heads])
v_obs(v_cache[:, tp * per_tp_heads:(tp + 1) * per_tp_heads])
def main(model: str,
bits: int = 8,
granularity: str = 'per_tensor',
symmetry: bool = True,
offload: bool = False,
max_seq_len: int = 2048,
num_tp: int = 1,
calib_dataset: str = 'c4',
calib_samples: int = 128,
output_dir: str = './kv_scales'):
assert granularity in ['per_tensor'], \
'Currently, only per-tensor quantization is supported for the kv cache.'
assert bits == 8, \
'Currently, only 8-bit quantization is supported for the kv cache.'
assert calib_dataset in ['c4', 'ptb', 'wikitext2', 'pileval'], \
'Currently, only `c4`, `ptb`, `wikitext2` and `pileval` are supported.'
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
model = AutoModel.from_pretrained(model)
# Make sure the forward pass returns `past_key_values`.
model.config.use_cache = True
print('Loading calibration dataset ...')
calib_loader, _ = get_calib_loaders(calib_dataset,
tokenizer,
nsamples=calib_samples,
seqlen=max_seq_len)
k_obs_list = list()
v_obs_list = list()
if offload:
import warnings
warnings.warn('You are using the `offload` mode, in which the '
'modules in the `OFFLOAD_MOD_MAP` will be moved to '
'the GPU during forward and kept on the CPU at other '
'times to save GPU memory.')
if type(model) not in OFFLOAD_MOD_MAP:
warnings.warn(f'{type(model)} is not in the `OFFLOAD_MOD_MAP`, '
'and by default, offloading will be done on '
'`nn.Linear`. You can add more robust modules to '
'the `OFFLOAD_MOD_MAP` for faster speed.')
# Fall back to `nn.Linear` when the model type is not registered,
# instead of raising a KeyError.
offload_mod = OFFLOAD_MOD_MAP.get(type(model), (nn.Linear, ))
with memory_efficient_inference(model, offload_mod):
for data in tqdm(calib_loader, desc='Calibrating: '):
if isinstance(data, torch.Tensor):
output = model(data.to('cuda'))
else:
output = model(data[0].to('cuda'))
kv_cache = output.past_key_values
stats_past_key_values(kv_cache, k_obs_list, v_obs_list,
symmetry, num_tp)
else:
model.to('cuda')
with torch.inference_mode():
for data in tqdm(calib_loader, desc='Calibrating: '):
if isinstance(data, torch.Tensor):
output = model(data.to('cuda'))
else:
output = model(data[0].to('cuda'))
kv_cache = output.past_key_values
stats_past_key_values(kv_cache, k_obs_list, v_obs_list,
symmetry, num_tp)
out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for i, (k_obs, v_obs) in enumerate(zip(k_obs_list, v_obs_list)):
layer = i // num_tp
tp = i % num_tp
save_path = out_dir / f'layers.{layer}.past_kv_scale.{tp}.weight'
if symmetry:
k_scale = max(k_obs.buffer) / (2**(bits - 1) - 1)
v_scale = max(v_obs.buffer) / (2**(bits - 1) - 1)
kv_qparams = np.array([k_scale, v_scale], dtype=np.float32)
kv_qparams.tofile(save_path)
print(f'Layer {layer} TP {tp} KV scales done.')
else:
k_min = min([min_k for min_k, _ in k_obs.buffer])
k_max = max([max_k for _, max_k in k_obs.buffer])
v_min = min([min_v for min_v, _ in v_obs.buffer])
v_max = max([max_v for _, max_v in v_obs.buffer])
k_scale = (k_max - k_min) / (2**bits - 1)
v_scale = (v_max - v_min) / (2**bits - 1)
k_zero = (-k_min / k_scale).round()
v_zero = (-v_min / v_scale).round()
kv_qparams = np.array([k_scale, k_zero, v_scale, v_zero],
dtype=np.float32)
kv_qparams.tofile(save_path)
print(f'Layer {layer} TP {tp} KV scales & zeros done.')
if __name__ == '__main__':
fire.Fire(main)
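For intuition, the sketch below reproduces the arithmetic that the exported per-tensor parameters imply when a K/V tensor is mapped to 8 bits and back. It only illustrates the formulas used in `main()` above; the actual int8 KV cache path runs inside TurboMind (`kCacheKVInt8`), not in Python.
```
import torch

k_cache = torch.randn(1, 32, 16, 128)  # [Bs, Heads, Tokens, Dims]
bits = 8

# Symmetric: scale = absmax / 127, values mapped to [-127, 127].
k_scale = k_cache.abs().max().item() / (2**(bits - 1) - 1)
k_int8 = (k_cache / k_scale).round().clamp(-127, 127).to(torch.int8)
k_dq_sym = k_int8.float() * k_scale

# Asymmetric: scale = (max - min) / 255, zero point shifts min to 0.
k_min, k_max = k_cache.min().item(), k_cache.max().item()
scale = (k_max - k_min) / (2**bits - 1)
zero = round(-k_min / scale)
k_uint8 = ((k_cache / scale).round() + zero).clamp(0, 255).to(torch.uint8)
k_dq_asym = (k_uint8.float() - zero) * scale

# Both reconstructions should stay close to the original fp values.
print((k_cache - k_dq_sym).abs().max(), (k_cache - k_dq_asym).abs().max())
```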
# Copyright (c) OpenMMLab. All rights reserved.
from .observer import Observer
__all__ = ['Observer']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Callable
class Observer:
"""The Observer class applies a user-specified function on its inputs and
stores the results in a buffer.
Args:
observe_fn (Callable[..., Any]): The function to apply on inputs.
"""
def __init__(self, observe_fn: Callable[..., Any]) -> None:
super().__init__()
self.fn = observe_fn
self.buffer = list()
self.enabled = False
def enable_observer(self, enabled: bool = True) -> None:
"""Enable or disable the observer.
Args:
enabled (bool, optional): Whether to enable the observer.
Defaults to True.
"""
self.enabled = enabled
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Apply the observer function on the input if the observer is enabled.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
if self.enabled:
self.buffer.append(self.fn(*args, **kwargs))
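A minimal usage sketch of the class above (the lambda plays the same role as the `absmax` helper in `lmdeploy/lite/apis/kv_qparams.py`):
```
import torch

from lmdeploy.lite.quantization import Observer

obs = Observer(lambda t: t.abs().max().item())
obs.enable_observer()

for _ in range(4):
    obs(torch.randn(2, 8))  # each call appends one statistic to the buffer

print(max(obs.buffer))  # largest absolute value seen during observation
```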
# Copyright (c) OpenMMLab. All rights reserved.
from .cal_qparams import (cal_qparams_per_channel_absmax,
cal_qparams_per_channel_minmax,
cal_qparams_per_group_absmax,
cal_qparams_per_group_minmax,
cal_qparams_per_tensor_absmax,
cal_qparams_per_tensor_minmax)
from .calib_dataloader import get_calib_loaders
from .collect import collect_target_modules, collect_target_weights
from .memory_efficient import memory_efficient_inference
__all__ = [
'cal_qparams_per_channel_absmax', 'cal_qparams_per_channel_minmax',
'cal_qparams_per_group_absmax', 'cal_qparams_per_group_minmax',
'cal_qparams_per_tensor_absmax', 'cal_qparams_per_tensor_minmax',
'get_calib_loaders', 'memory_efficient_inference',
'collect_target_modules', 'collect_target_weights'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import NamedTuple, Optional
import torch
class QParams(NamedTuple):
"""A class to hold the quantization parameters."""
scales: torch.Tensor
zero_points: Optional[torch.Tensor]
@torch.no_grad()
def cal_qparams_per_channel_absmax(w: torch.Tensor, n_bits: int) -> QParams:
"""Calculate quantization parameters for each channel using absolute max
value."""
scales = w.abs().max(dim=-1, keepdim=True)[0]
q_max = 2**(n_bits - 1) - 1
scales = scales.clamp_(min=1e-5).div_(q_max)
return QParams(scales=scales, zero_points=None)
@torch.no_grad()
def cal_qparams_per_channel_minmax(w: torch.Tensor, n_bits: int) -> QParams:
"""Calculate quantization parameters for each channel using min and max
values."""
w_min = w.min(dim=-1, keepdim=True)[0]
w_max = w.max(dim=-1, keepdim=True)[0]
q_max = 2**n_bits - 1
scales = (w_max - w_min)
scales = scales.clamp_(min=1e-5).div_(q_max)
zero_points = (-w_min / scales).round()
return QParams(scales=scales, zero_points=zero_points)
@torch.no_grad()
def cal_qparams_per_group_absmax(w: torch.Tensor, n_bits: int,
group_size: int) -> QParams:
"""Calculate quantization parameters for each group using absolute max
value."""
outc, inc = w.shape
assert inc >= group_size, \
'Input channels should be greater than or equal to group_size.'
assert inc % group_size == 0, \
'Input channels should be divisible by group_size.'
scales = w.abs().reshape(outc, -1, group_size).max(dim=-1, keepdim=True)[0]
q_max = 2**(n_bits - 1) - 1
scales = scales.clamp_(min=1e-5).div_(q_max)
return QParams(scales=scales, zero_points=None)
@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor, n_bits: int,
group_size: int) -> QParams:
"""Calculate quantization parameters for each group using min and max
values."""
outc, inc = w.shape
assert inc >= group_size, \
'Input channels should be greater than or equal to group_size.'
assert inc % group_size == 0, \
'Input channels should be divisible by group_size.'
w_group_wise = w.reshape(outc, -1, group_size)
w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
w_max = w_group_wise.max(dim=-1, keepdim=True)[0]
q_max = 2**n_bits - 1
scales = (w_max - w_min)
scales = scales.clamp_(min=1e-5).div_(q_max)
zero_points = (-w_min / scales).round()
return QParams(scales=scales, zero_points=zero_points)
@torch.no_grad()
def cal_qparams_per_tensor_minmax(w: torch.Tensor, n_bits: int) -> QParams:
"""Calculate quantization parameters for the entire tensor using min and
max values."""
w_min = w.min()
w_max = w.max()
q_max = 2**n_bits - 1
scales = (w_max - w_min)
scales = scales.clamp_(min=1e-5).div_(q_max)
zero_points = (-w_min / scales).round()
return QParams(scales=scales, zero_points=zero_points)
@torch.no_grad()
def cal_qparams_per_tensor_absmax(w: torch.Tensor, n_bits: int) -> QParams:
"""Calculate quantization parameters for the entire tensor using absolute
max value."""
scales = w.abs().max()
q_max = 2**(n_bits - 1) - 1
scales = scales.clamp_(min=1e-5).div_(q_max)
return QParams(scales=scales, zero_points=None)
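A short sketch of how one of these helpers could be used to fake-quantize a weight matrix (per-group min-max, 4-bit, group size 128); the shapes are arbitrary:
```
import torch

from lmdeploy.lite.utils import cal_qparams_per_group_minmax

w = torch.randn(256, 512)  # [out_channels, in_channels]
qp = cal_qparams_per_group_minmax(w, n_bits=4, group_size=128)

# Round onto the 4-bit grid, then map back to float.
w_grouped = w.reshape(w.shape[0], -1, 128)
w_q = (w_grouped / qp.scales + qp.zero_points).round().clamp(0, 2**4 - 1)
w_dq = ((w_q - qp.zero_points) * qp.scales).reshape_as(w)

print((w - w_dq).abs().max())  # worst-case quantization error
```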
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
def set_seed(seed):
np.random.seed(seed)
torch.random.manual_seed(seed)
def get_wikitext2(tokenizer, nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
trainenc = tokenizer('\n\n'.join(traindata['text']), return_tensors='pt')
testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc
def get_ptb(tokenizer, nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
valdata = load_dataset('ptb_text_only',
'penn_treebank',
split='validation')
trainenc = tokenizer('\n\n'.join(traindata['sentence']),
return_tensors='pt')
testenc = tokenizer('\n\n'.join(valdata['sentence']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc
def get_c4(tokenizer, nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset(
'allenai/c4',
'allenai--c4',
data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
split='train',
use_auth_token=False)
valdata = load_dataset(
'allenai/c4',
'allenai--c4',
data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
split='validation',
use_auth_token=False)
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
while True:
i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
if trainenc.input_ids.shape[1] >= seqlen:
break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
import random
random.seed(0)
valenc = []
for _ in range(256):
while True:
i = random.randint(0, len(valdata) - 1)
tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
if tmp.input_ids.shape[1] >= seqlen:
break
i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
valenc.append(tmp.input_ids[:, i:j])
valenc = torch.hstack(valenc)
class TokenizerWrapper:
def __init__(self, input_ids):
self.input_ids = input_ids
valenc = TokenizerWrapper(valenc)
return trainloader, valenc
def get_ptb_new(tokenizer, nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt')
testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc
def get_c4_new(tokenizer, nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset(
'allenai/c4',
'allenai--c4',
data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
split='train')
valdata = load_dataset(
'allenai/c4',
'allenai--c4',
data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
split='validation')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
while True:
i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
if trainenc.input_ids.shape[1] >= seqlen:
break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
valenc = valenc.input_ids[:, :(256 * seqlen)]
class TokenizerWrapper:
def __init__(self, input_ids):
self.input_ids = input_ids
valenc = TokenizerWrapper(valenc)
return trainloader, valenc
def get_pileval(tokenizer, nsamples, seed, seqlen=512):
from datasets import load_dataset
from datasets.builder import DatasetGenerationError
try:
dataset = load_dataset(
'json',
data_files='https://the-eye.eu/public/AI/pile/val.jsonl.zst',
split='train')
except DatasetGenerationError:
raise InterruptedError('There have been some issues when generating '
'the dataset. You could try to download it '
'locally first, and replace the `data_files` '
'with local addresses or use other datasets '
'(c4, wiki, ptb).')
dataset = dataset.shuffle(seed=seed)
samples = []
n_run = 0
for data in dataset:
line = data['text']
line = line.strip()
line_encoded = tokenizer.encode(line)
if len(line_encoded) > 512:
continue
sample = torch.tensor([line_encoded])
if sample.numel() == 0:
continue
samples.append(sample)
n_run += 1
if n_run == nsamples:
break
# now concatenate all samples and split according to block size
cat_samples = torch.cat(samples, dim=1)
n_split = cat_samples.shape[1] // seqlen
print(f' * Split into {n_split} blocks')
return [
cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split)
], None
def get_calib_loaders(name,
tokenizer,
nsamples=128,
seed=0,
seqlen=2048,
model=''):
if 'wikitext2' in name:
return get_wikitext2(tokenizer, nsamples, seed, seqlen, model)
if 'ptb' in name:
if 'new' in name:
return get_ptb_new(tokenizer, nsamples, seed, seqlen, model)
return get_ptb(tokenizer, nsamples, seed, seqlen, model)
if 'c4' in name:
if 'new' in name:
return get_c4_new(tokenizer, nsamples, seed, seqlen, model)
return get_c4(tokenizer, nsamples, seed, seqlen, model)
if 'pileval' in name:
return get_pileval(tokenizer, nsamples, seed, seqlen)
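A minimal sketch of how these loaders are consumed (it mirrors the calibration loop in `kv_qparams.py`; the model path is a placeholder):
```
import torch
from transformers import AutoTokenizer

from lmdeploy.lite.utils import get_calib_loaders

tokenizer = AutoTokenizer.from_pretrained('path/to/hf_model', use_fast=False)

# c4/ptb/wikitext2 yield (input_ids, target) pairs; pileval yields
# plain input_ids tensors, so both cases are handled below.
calib_loader, _ = get_calib_loaders('c4', tokenizer, nsamples=8, seqlen=2048)

for data in calib_loader:
    input_ids = data if isinstance(data, torch.Tensor) else data[0]
    print(input_ids.shape)  # [1, seqlen]
```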
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
def collect_target_weights(model: nn.Module, target_module_types: type,
skip_modules: list) -> dict:
"""Collects target weight tensors in the model and returns them in a
dictionary.
Args:
model (nn.Module): Model containing the target modules.
target_module_types (type): Target module type, e.g., nn.Linear.
skip_modules (list): List of modules that should not be included in
the result.
Returns:
dict: A dictionary containing the target weight tensors in the model.
"""
target_weights = {}
for name, module in model.named_modules():
if isinstance(module,
target_module_types) and name not in skip_modules:
assert hasattr(module, 'weight')
target_weights[name] = module.weight
return target_weights
def collect_target_modules(model: nn.Module,
target_module_types: type,
skip_modules: list = []) -> dict:
"""Collects target weight tensors in the model and returns them in a
dictionary.
Args:
model (nn.Module): Model containing the target modules.
target (type): Target module type, e.g., nn.Linear.
skip_modules (list): List of modules that should not be included in
the result.
Returns:
dict: A dictionary containing the target modules in the model.
"""
target_modules = {}
for name, module in model.named_modules():
if isinstance(module,
target_module_types) and name not in skip_modules:
target_modules[name] = module
return target_modules
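For example, collecting every `nn.Linear` of a toy model while skipping one of them by name (a sketch, not tied to any particular checkpoint):
```
from torch import nn

from lmdeploy.lite.utils import collect_target_modules, collect_target_weights

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

mods = collect_target_modules(model, nn.Linear, skip_modules=['2'])
weights = collect_target_weights(model, nn.Linear, skip_modules=['2'])

print(list(mods.keys()))   # ['0'] - the second Linear ('2') was skipped
print(weights['0'].shape)  # torch.Size([32, 16])
```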
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
import torch
from torch import nn
@contextmanager
def memory_efficient_inference(model: nn.Module,
target=(nn.Linear, ),
device='cuda'):
"""Context manager for memory-efficient inference on specified modules of a
PyTorch model.
Args:
model (nn.Module): The model to be used for inference.
target (tuple): A tuple containing the target module classes to move to
GPU during forward pass.
device (str): The device ('cpu' or 'cuda') where the model will be
moved during inference.
Yields:
None
Example:
with memory_efficient_inference(model, target=nn.Linear, device='cuda'):
output = model(input)
"""
def _before_forward_hook(m, input):
m.to(device)
def _after_forward_hook(m, input, output):
m.to('cpu')
torch.cuda.empty_cache()
def _to_device(m, spec_modules, dev):
if len(spec_modules) == 0:
m.to(dev)
return
for child in m.children():
if isinstance(child, spec_modules):
child.to('cpu')
else:
_to_device(child, spec_modules, dev)
m.to(dev)
_to_device(model, target, device)
# enter
hook_handles = []
for module in model.modules():
if isinstance(module, target):
before_h = module.register_forward_pre_hook(_before_forward_hook)
after_h = module.register_forward_hook(_after_forward_hook)
hook_handles.append(before_h)
hook_handles.append(after_h)
with torch.inference_mode():
yield
# exit
for h in hook_handles:
h.remove()
model.to('cpu')
torch.cuda.empty_cache()
......@@ -322,7 +322,7 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
return _params[name]
def get_tensor_transposed(name: str):
if not name in _params and name.find('bias'):
if name not in _params and name.find('bias'):
return None
return _params[name].t()
......
# Copyright (c) OpenMMLab. All rights reserved.
from .tokenizer import Tokenizer, Preprocessor, Postprocessor
from .tokenizer import Postprocessor, Preprocessor, Tokenizer
from .turbomind import TurboMind
from .turbomind import TurboMind
\ No newline at end of file
__all__ = ['Postprocessor', 'Preprocessor', 'Tokenizer', 'TurboMind']
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Sequence, Optional, Union
from torch.nn.utils.rnn import pad_sequence
from typing import Sequence, Union
import torch
from torch.nn.utils.rnn import pad_sequence
class Tokenizer:
......@@ -67,8 +69,10 @@ class Tokenizer:
return self.model.decode(t,
skip_special_tokens=skip_special_tokens)
class Preprocessor:
def __init__(self, tokenizer:Tokenizer):
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
self.bos_token_id = tokenizer.bos_token_id
self.eos_token_id = tokenizer.eos_token_id
......@@ -87,13 +91,13 @@ class Preprocessor:
ids, ids' length and requested output length
"""
if isinstance(prompts, str):
input0 = [[prompts]]
_ = [[prompts]]
elif isinstance(prompts, Sequence):
input0 = [[prompt] for prompt in prompts]
_ = [[prompt] for prompt in prompts]
else:
assert 0, f'str or Sequence[str] prompts are expected but got ' \
f'{type(prompts)}'
start_ids = [
torch.IntTensor(self.tokenizer.encode(prompt))
for prompt in prompts
......@@ -106,7 +110,8 @@ class Preprocessor:
class Postprocessor:
def __init__(self, tokenizer:Tokenizer):
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
self.bos_token_id = tokenizer.bos_token_id
self.eos_token_id = tokenizer.eos_token_id
......@@ -128,4 +133,4 @@ class Postprocessor:
for tokens, _len in zip(output_ids, seqlen):
output = self.tokenizer.decode(tokens[:_len])
outputs.append(output)
return outputs
\ No newline at end of file
return outputs
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Iterable
import sys
import os.path as osp
import torch
import sys
from typing import Iterable, List
import numpy as np
import lmdeploy
from lmdeploy.model import MODELS
from .tokenizer import Tokenizer, Preprocessor, Postprocessor
import torch
from torch.nn.utils.rnn import pad_sequence
import lmdeploy
# TODO: find another way import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
sys.path.append(osp.join(lmdeploy_dir, 'lib'))
import _turbomind as _tm
import _turbomind as _tm # noqa: E402
def _stop_words(stop_words: List[int]):
if stop_words is None:
return None
assert isinstance(stop_words, List) and \
all(isinstance(elem, int) for elem in stop_words), \
f'stop_words must be a list but got {type(stop_words)}'
all(isinstance(elem, int) for elem in stop_words), \
f'stop_words must be a list but got {type(stop_words)}'
# each id in stop_words represents a stop word
# refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
# detailed explanation about fastertransformer's stop_words
......@@ -136,8 +137,9 @@ class TurboMindInstance:
input_ids = [torch.IntTensor(ids) for ids in input_ids]
input_lengths = torch.IntTensor([len(ids) for ids in input_ids])
input_ids = pad_sequence(
input_ids, batch_first=True, padding_value=self.eos_id)
input_ids = pad_sequence(input_ids,
batch_first=True,
padding_value=self.eos_id)
input_lengths = input_lengths.detach().cpu().numpy()
if isinstance(session_id, int):
......@@ -147,8 +149,9 @@ class TurboMindInstance:
inputs = dict(
input_ids=input_ids,
input_lengths=input_lengths,
request_output_len=np.full(
input_lengths.shape, request_output_len, dtype=np.uint32),
request_output_len=np.full(input_lengths.shape,
request_output_len,
dtype=np.uint32),
runtime_top_k=_broadcast_np(top_k, np.uint32),
runtime_top_p=_broadcast_np(top_p, np.float32),
temperature=_broadcast_np(temperature, np.float32),
......