Commit 4f3e977c authored by Haotian Tang's avatar Haotian Tang
Browse files

[Major] Add TinyChat and demo.

parent 79048993
......@@ -12,7 +12,13 @@ The current release supports:
- Efficient CUDA kernel implementation for fast inference (support context and decoding stage).
- Examples on 4-bit inference of an instruction-tuned model (Vicuna) and multi-modal LM (LLaVA).
![TinyChat on RTX 4090: W4A16 is 2.3x faster than FP16](./tinychat/figures/4090_example.gif)
Check out [TinyChat](tinychat), which delievers 2.3x faster inference performance for the **LLaMA-2** chatbot!
## News
- [2023/07] 🔥 We released TinyChat, an efficient and minimal chatbot interface based on AWQ. LLama-2-chat models are supported! Check out our implementation [here](tinychat).
- [2023/07] 🔥 We added AWQ support and pre-computed search results for Llama-2 models (7B & 13B). Checkout our model zoo [here](https://huggingface.co/datasets/mit-han-lab/awq-model-zoo)!
- [2023/07] We extended the support for more LLM models including MPT, Falcon, and BLOOM.
......@@ -40,7 +46,7 @@ pip install --upgrade pip # enable PEP 660 support
pip install -e .
```
3. Install efficient W4A16 (4-bit weight, 16-bit activation) CUDA kernel
3. Install efficient W4A16 (4-bit weight, 16-bit activation) CUDA kernel and optimized FP16 kernels (e.g. layernorm, positional encodings).
```
cd awq/kernels
python setup.py install
......
/*
Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
*/
#include <torch/extension.h>
#include <cuda_fp16.h>
#include "reduction.cuh"
#include "layernorm.h"
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>
static inline __device__ float to_float(half src)
{
return __half2float(src);
}
static inline __device__ float to_float(float src)
{
return src;
}
template<typename T>
__global__ void generalT5LayerNorm(
const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n)
{
// layernorm module in the T5 style No bias and no subtraction of mean.
const int tid = threadIdx.x;
__shared__ float s_variance;
float variance = 0.0f;
float local_var_sum = 0.0f;
for (int i = tid; i < n; i += blockDim.x) {
float diff = to_float(__ldg(&input[blockIdx.x * n + i]));
local_var_sum += diff * diff;
}
variance = blockReduceSum(local_var_sum);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / (float)n + layernorm_eps);
}
__syncthreads();
for (int i = tid; i < n; i += blockDim.x) {
output[blockIdx.x * n + i] =
clamp_inf_for_half<T>((to_float(input[blockIdx.x * n + i]) * s_variance) * to_float(__ldg(&gamma[i])));
}
}
template<typename T>
void invokeGeneralT5LayerNorm(T* out,
const T* input,
const T* gamma,
// const T* beta,
const float layernorm_eps,
const int m,
const int n)
{
dim3 grid(m);
dim3 block(min(n, 1024));
/* For general cases, n is equal to hidden_units, e.g., 512/1024.
Since we have warp shuffle inside the code, block.x % 32 should be 0.
*/
if (n % 32 != 0) {
block.x = 1024;
}
block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x
/* should pay attention to the rsqrt precision*/
generalT5LayerNorm<T><<<grid, block>>>(input, gamma, out, layernorm_eps, m, n); // For gpt-3
}
template void invokeGeneralT5LayerNorm(half* out,
const half* input,
const half* gamma,
// const half* beta,
const float layernorm_eps,
const int m,
const int n);
template void invokeGeneralT5LayerNorm(float* out,
const float* input,
const float* gamma,
// const half* beta,
const float layernorm_eps,
const int m,
const int n);
// input b, n, c
void layernorm_forward_cuda(
torch::Tensor _input,
torch::Tensor _gamma,
torch::Tensor _out,
float eps)
{
int m = _input.size(0) * _input.size(1);
int n = _input.size(2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_input));
auto input = reinterpret_cast<half*>(_input.data_ptr<at::Half>());
auto gamma = reinterpret_cast<half*>(_gamma.data_ptr<at::Half>());
auto out = reinterpret_cast<half*>(_out.data_ptr<at::Half>());
invokeGeneralT5LayerNorm(out, input, gamma, eps, m, n);
}
#include <torch/extension.h>
void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps);
/*
Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/reduce_kernel_utils.cuh
*/
#pragma once
#include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <float.h>
#include <type_traits>
static const float HALF_FLT_MAX = 65504.F;
#define FINAL_MASK 0xffffffff
template<typename T>
inline __device__ T add(T a, T b) {
return a + b;
}
template<>
inline __device__ half2 add(half2 a, half2 b) {
return __hadd2(a, b);
}
template<>
inline __device__ half add(half a, half b) {
return __hadd(a, b);
}
template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
return val;
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
// blockDim.x is not divided by 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
template<typename T>
__device__ __forceinline__ T clamp_inf_for_half(const float input)
{
return input;
}
template<>
__device__ __forceinline__ half clamp_inf_for_half(const float input)
{
// clamp inf values to enable fp16 training
return input > 0.0f ? __float2half(min(input, HALF_FLT_MAX - 1000)) : __float2half(max(input, -HALF_FLT_MAX + 1000));
}
#pragma once
#include <torch/extension.h>
void rotary_embedding_neox(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache);
\ No newline at end of file
/*
Adapted from the VLLM project:
https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "pos_encoding.h"
template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
const int64_t* __restrict__ positions, // [num_tokens]
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int stride,
const int num_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2;
const int n = num_heads * embed_dim;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int out_x = token_idx * stride + head_idx * head_size + x_index;
const int out_y = token_idx * stride + head_idx * head_size + y_index;
const scalar_t cos = __ldg(cache_ptr + x_index);
const scalar_t sin = __ldg(cache_ptr + y_index);
const scalar_t q_x = query[token_head + x_index];
const scalar_t q_y = query[token_head + y_index];
query[out_x] = q_x * cos - q_y * sin;
query[out_y] = q_y * cos + q_x * sin;
const scalar_t k_x = key[token_head + x_index];
const scalar_t k_y = key[token_head + y_index];
key[out_x] = k_x * cos - k_y * sin;
key[out_y] = k_y * cos + k_x * sin;
}
}
void rotary_embedding_neox(
torch::Tensor& positions, // [b, num_tokens]
torch::Tensor& query, // [b, num_tokens, 1, num_heads, head_size]
torch::Tensor& key, // [b, num_tokens, 1, num_heads, head_size]
int head_size,
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
{
int num_tokens = query.size(0) * query.size(1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-2);
int stride = num_heads * head_size;
// TORCH_CHECK(stride == key.stride(0));
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
query.scalar_type(),
"rotary_embedding_neox",
[&] {
rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
stride,
num_heads,
head_size);
});
}
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "position_embedding/pos_encoding.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
}
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <torch/extension.h>
#include "gemm_cuda.h"
#include "dequantize.cuh"
......@@ -107,7 +119,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
// TODO: Shang: double check how to get 8.
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
......@@ -465,4 +476,3 @@ torch::Tensor gemm_forward_cuda(
}
return _out_feats.sum(0);
}
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "gemm_cuda.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("gemm_forward_cuda", &gemm_forward_cuda, "our sparse conv kernel");
}
......@@ -7,12 +7,17 @@ extra_compile_args = {
}
setup(
name="f16s4_gemm",
name="awq_inference_engine",
packages=find_packages(),
ext_modules=[
CUDAExtension(
name="f16s4_gemm",
sources=["pybind.cpp", "gemm_cuda_gen.cu"],
name="awq_inference_engine",
sources=[
"csrc/pybind.cpp",
"csrc/quantization/gemm_cuda_gen.cu",
"csrc/layernorm/layernorm.cu",
"csrc/position_embedding/pos_encoding_kernels.cu"
],
extra_compile_args=extra_compile_args,
),
],
......
import math
import torch
import torch.nn as nn
import f16s4_gemm # with CUDA kernels
import awq_inference_engine # with CUDA kernels
class ScaledActivation(nn.Module):
......@@ -89,7 +89,7 @@ class WQLinear(nn.Module):
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
out = f16s4_gemm.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
......
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
......@@ -9,7 +8,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
......@@ -17,7 +15,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
......@@ -25,13 +22,12 @@
"- [AWQ](https://github.com/mit-han-lab/llm-awq)\n",
"- [Pytorch](https://pytorch.org/)\n",
"- [Accelerate](https://github.com/huggingface/accelerate)\n",
"- [FastChat](https://github.com/lm-sys/FastChat)\n",
"- [Transformers](https://github.com/huggingface/transformers)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
......@@ -39,16 +35,16 @@
"from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n",
"from awq.quantize.quantizer import real_quantize_model_weight\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n",
"from fastchat.serve.cli import SimpleChatIO\n",
"from fastchat.serve.inference import generate_stream \n",
"from fastchat.conversation import get_conv_template\n",
"from tinychat.demo import gen_params, stream_output\n",
"from tinychat.stream_generators import StreamGenerator\n",
"from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp\n",
"from tinychat.utils.prompt_templates import get_prompter\n",
"import os\n",
"# This demo only support single GPU for now\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
......@@ -69,12 +65,13 @@
"metadata": {},
"outputs": [],
"source": [
"model_path = \"\" # the path of vicuna-7b model\n",
"load_quant_path = \"quant_cache/vicuna-7b-w4-g128-awq.pt\""
"# model_path = \"\" # the path of vicuna-7b model\n",
"# load_quant_path = \"quant_cache/vicuna-7b-w4-g128-awq.pt\"\n",
"model_path = \"/data/llm/checkpoints/vicuna-hf/vicuna-7b\"\n",
"load_quant_path = \"/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
......@@ -87,24 +84,26 @@
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00, 2.50s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"* skipping lm_head\n"
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8b79a82b73ab4d9191ba54f5d0f8cb86",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"real weight quantization...: 100%|██████████| 224/224 [00:26<00:00, 8.40it/s]\n"
"real weight quantization...(init only): 100%|███████████████████| 32/32 [00:11<00:00, 2.69it/s]\n",
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n",
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
]
}
],
......@@ -134,70 +133,126 @@
"name": "stdout",
"output_type": "stream",
"text": [
"User: How can I improve my time management skills?\n",
"ASSISTANT: Time management skills can be improved through a combination of techniques, such as setting clear goals, prioritizing tasks, and using time-saving tools and strategies. Here are some tips to help you improve your time management skills:\n",
"\n",
"1. Set clear goals: Establish clear and specific goals for what you want to achieve. This will help you prioritize your tasks and focus your efforts.\n",
"2. Prioritize tasks: Identify the most important tasks that need to be completed and prioritize them accordingly. Use the Eisenhower matrix to categorize tasks into urgent and important, important but not urgent, urgent but not important, and not urgent or important.\n",
"3. Use time-saving tools and strategies: Use tools like calendars, to-do lists, and time trackers to help you manage your time more effectively. Also, consider using time-saving strategies like batching, delegating, and automating tasks.\n",
"4. Practice time management techniques: Practice time management techniques like the Pomodoro technique, the 80/20 rule, and Parkinson's law to help you work more efficiently.\n",
"5. Learn to say no: Learn to say no to non-essential tasks and commitments to free up more time for what's important.\n",
"6. Take breaks: Take regular breaks throughout the day to recharge and refocus.\n",
"7. Review and adjust: Regularly review and adjust your time management strategies to ensure they are working for you.\n",
"\n",
"Remember, time management is a skill that takes time and practice to develop. Be patient with yourself and keep working on improving your time management skills.\n",
"exit...\n"
"[Warning] Calling a fake MLP fusion. But still faster than Huggingface Implimentation.\n"
]
},
{
"data": {
"text/plain": [
"LlamaForCausalLM(\n",
" (model): LlamaModel(\n",
" (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
" (layers): ModuleList(\n",
" (0-31): 32 x LlamaDecoderLayer(\n",
" (self_attn): QuantLlamaAttention(\n",
" (qkv_proj): WQLinear(in_features=4096, out_features=12288, bias=False, w_bit=4, group_size=128)\n",
" (o_proj): WQLinear(in_features=4096, out_features=4096, bias=False, w_bit=4, group_size=128)\n",
" (rotary_emb): QuantLlamaRotaryEmbedding()\n",
" )\n",
" (mlp): QuantLlamaMLP(\n",
" (down_proj): WQLinear(in_features=11008, out_features=4096, bias=False, w_bit=4, group_size=128)\n",
" )\n",
" (input_layernorm): FTLlamaRMSNorm()\n",
" (post_attention_layernorm): FTLlamaRMSNorm()\n",
" )\n",
" )\n",
" (norm): FTLlamaRMSNorm()\n",
" )\n",
" (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conv = get_conv_template(\"vicuna_v1.1\")\n",
"chatio = SimpleChatIO()\n",
"\n",
"inp = \"How can I improve my time management skills?\"\n",
"print(\"User:\", inp)\n",
"\n",
"make_quant_attn(model, \"cuda:0\")\n",
"make_quant_norm(model)\n",
"make_fused_mlp(model)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
"USER: Show me some attractions in Boston.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ASSISTANT: 1. Boston Public Library\n",
"2. Fenway Park\n",
"3. Harvard Square\n",
"4. Boston Common\n",
"5. Freedom Trail\n",
"6. Museum of Fine Arts\n",
"7. Isabella Stewart Gardner Museum\n",
"8. Paul Revere House\n",
"9. New England Aquarium\n",
"10. Museum of Science\n",
"==================================================\n",
"Speed of Inference\n",
"--------------------------------------------------\n",
"Context Stage : 7.18 ms/token\n",
"Generation Stage : 9.49 ms/token\n",
"Average Speed : 8.53 ms/token\n",
"==================================================\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"USER: \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EXIT...\n"
]
}
],
"source": [
"model_prompter = get_prompter(\"llama\", model_path)\n",
"stream_generator = StreamGenerator\n",
"count = 0\n",
"while True:\n",
" if not inp:\n",
" try:\n",
" inp = chatio.prompt_for_input(conv.roles[0])\n",
" except EOFError:\n",
" inp = \"\"\n",
" if not inp:\n",
" print(\"exit...\")\n",
" # Get input from the user\n",
" input_prompt = input(\"USER: \")\n",
" if input_prompt == \"\":\n",
" print(\"EXIT...\")\n",
" break\n",
"\n",
" conv.append_message(conv.roles[0], inp)\n",
" conv.append_message(conv.roles[1], None)\n",
"\n",
" generate_stream_func = generate_stream\n",
" prompt = conv.get_prompt()\n",
"\n",
" gen_params = {\n",
" \"model\": model_path,\n",
" \"prompt\": prompt,\n",
" \"temperature\": 0.3,\n",
" \"repetition_penalty\": 1.0,\n",
" \"max_new_tokens\": 512,\n",
" \"stop\": conv.stop_str,\n",
" \"stop_token_ids\": conv.stop_token_ids,\n",
" \"echo\": False,\n",
" }\n",
"\n",
" chatio.prompt_for_output(conv.roles[1])\n",
" output_stream = generate_stream_func(model, tokenizer, gen_params, \"cuda\")\n",
" outputs = chatio.stream_output(output_stream)\n",
" conv.update_last_message(outputs.strip())\n",
" \n",
" inp = None"
" model_prompter.insert_prompt(input_prompt)\n",
" output_stream = stream_generator(model, tokenizer, model_prompter.model_input, gen_params, device=\"cuda:0\")\n",
" outputs = stream_output(output_stream) \n",
" model_prompter.update_template(outputs)\n",
" count += 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "awq",
"display_name": "Python (awq)",
"language": "python",
"name": "python3"
"name": "awq"
},
"language_info": {
"codemirror_mode": {
......@@ -209,10 +264,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
......@@ -7,7 +7,7 @@ name = "awq"
version = "0.1.0"
description = "An efficient and accurate low-bit weight quantization(INT3/4) method for LLMs."
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
......@@ -15,12 +15,14 @@ classifiers = [
dependencies = [
"accelerate", "sentencepiece", "tokenizers>=0.12.1",
"torch", "torchvision",
"transformers>=4.28.0",
"lm_eval"
"transformers>=4.31.0",
"lm_eval", "texttable",
"toml", "attributedict",
"protobuf"
]
[tool.setuptools.packages.find]
exclude = ["results*", "scripts*", "examples*"]
[tool.wheel]
exclude = ["results*", "scripts*", "examples*"]
\ No newline at end of file
exclude = ["results*", "scripts*", "examples*"]
# TinyChat: Efficient and Minimal Chatbot with AWQ
We introduce TinyChat, a cutting-edge chatbot interface designed for minimal resource consumption and fast inference speed on GPU platforms. It allows for seamless deployment on low-power edge devices like the NVIDIA Jetson Orin, empowering users with a responsive conversational experience like never before.
The current release supports:
- LLaMA-2-7B/13B-chat;
- Vicuna;
- MPT-chat;
- Falcon-instruct.
## Contents
- [Examples](#examples)
- [Usage](#usage)
- [Reference](#reference)
## Examples
Thanks to AWQ, TinyChat can now deliver more prompt responses through 4-bit inference. The following examples showcase that TinyChat's W4A16 generation is 2.3x faster on RTX 4090 and 1.4x faster on Jetson Orin, compared to the FP16 baselines. (Tested with [LLaMA-2-7b]( https://huggingface.co/meta-llama/Llama-2-7b-chat-hf ) model.)
* TinyChat on RTX 4090:
![TinyChat on RTX 4090: W4A16 is 2.3x faster than FP16](./figures/4090_example.gif)
* TinyChat on Jetson Orin:
![TinyChat on Jetson Orin: W4A16 is 1.4x faster than FP16](./figures/orin_example.gif)
## Usage
1. Please follow the [AWQ installation guidance](https://github.com/mit-han-lab/llm-awq#readme) to install AWQ and its dependencies.
2. Download the pretrained instruction-tuned LLMs:
- For LLaMA-2-chat, please refer to [this link](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf);
- For Vicuna, please refer to [this link](https://huggingface.co/lmsys/);
- For MPT-chat, please refer to [this link](https://huggingface.co/mosaicml/mpt-7b-chat);
- For Falcon-instruct, please refer to [this link](https://huggingface.co/tiiuae/falcon-7b-instruct).
3. Quantize instruction-tuned LLMs with AWQ:
- We provide pre-computed AWQ search results for multiple model families, including LLaMA, OPT, Vicuna, and LLaVA. To get the pre-computed AWQ search results, run:
```bash
# git lfs install # install git lfs if not already
git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo awq_cache
```
- You may run a one-line starter below:
```bash
./scripts/llama2_demo.sh
```
Alternatively, you may go through the process step by step. We will demonstrate the quantization process with LLaMA-2. For all other models except Falcon, one only needs to change the `model_path` and saving locations. For Falcon-7B, we also need to change `q_group_size` from 128 to 64.
- Perform AWQ search and save search results (we already did it for you):
```bash
mkdir awq_cache
python -m awq.entry --model_path /PATH/TO/LLAMA2/llama-2-7b-chat \
--w_bit 4 --q_group_size 128 \
--run_awq --dump_awq awq_cache/llama-2-7b-chat-w4-g128.pt
```
- Generate real quantized weights (INT4):
```bash
mkdir quant_cache
python -m awq.entry --model_path /PATH/TO/LLAMA2/llama-2-7b-chat \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/llama-2-7b-chat-w4-g128.pt \
--q_backend real --dump_quant quant_cache/llama-2-7b-chat-w4-g128-awq.pt
```
4. Run the TinyChat demo:
```bash
cd tinychat
python demo.py --model_type llama \
--model_path /PATH/TO/LLAMA2/llama-2-7b-chat \
--q_group_size 128 --load_quant quant_cache/llama-2-7b-chat-w4-g128-awq.pt \
    --precision W4A16
```
Note: if you use Falcon-7B-instruct, please remember to also change `q_group_size` to 64. You may also run the following command to execute the chatbot in FP16 to compare the speed and quality of language generation:
```bash
python demo.py --model_type llama \
--model_path /PATH/TO/LLAMA2/llama-2-7b-chat \
--precision W16A16
```
## Reference
TinyChat is inspired by the following open-source projects: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [vLLM](https://github.com/vllm-project/vllm), [FastChat](https://github.com/lm-sys/FastChat).
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
from attributedict.collections import AttributeDict
from tinychat.stream_generators import StreamGenerator, FalconStreamGenerator
from tinychat.utils.load_quant import load_awq_model, load_awq_llama_fast
from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# opt_params in TinyLLMEngine
gen_params = AttributeDict([
("seed", -1), # RNG seed
("n_threads", 1), # TODO: fix this
("n_predict", 512), # new tokens to predict
("n_parts", -1), # amount of model parts (-1: determine from model dimensions)
("n_ctx", 512), # context size
("n_batch", 512), # batch size for prompt processing (must be >=32 to use BLAS)
("n_keep", 0), # number of tokens to keep from initial prompt
("n_vocab", 50272), # vocabulary size
# sampling parameters
("logit_bias", dict()), # logit bias for specific tokens: <int, float>
("top_k", 40), # <= 0 to use vocab size
("top_p", 0.95), # 1.0 = disabled
("tfs_z", 1.00), # 1.0 = disabled
("typical_p", 1.00), # 1.0 = disabled
("temp", 0.70), # 1.0 = disabled
("repeat_penalty", 1.10), # 1.0 = disabled
("repeat_last_n", 64), # last n tokens to penalize (0 = disable penalty, -1 = context size)
("frequency_penalty", 0.00),# 0.0 = disabled
("presence_penalty", 0.00), # 0.0 = disabled
("mirostat", 0), # 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
("mirostat_tau", 5.00), # target entropy
("mirostat_eta", 0.10), # learning rate
])
def stream_output(output_stream):
print(f"ASSISTANT: ", end="", flush=True)
pre = 0
for outputs in output_stream:
output_text = outputs["text"]
output_text = output_text.strip().split(" ")
now = len(output_text) - 1
if now > pre:
print(" ".join(output_text[pre:now]), end=" ", flush=True)
pre = now
print(" ".join(output_text[pre:]), flush=True)
if "timing" in outputs and outputs["timing"] is not None:
timing = outputs["timing"]
context_tokens = timing["context_tokens"]
context_time = timing["context_time"]
total_tokens = timing["total_tokens"]
generation_time_list = timing["generation_time_list"]
generation_tokens = len(generation_time_list)
average_speed = (context_time + np.sum(generation_time_list)) / (context_tokens + generation_tokens)
print("=" * 50)
print("Speed of Inference")
print("-" * 50)
# print(f"Context Stage : {context_time/context_tokens * 1000:.2f} ms/token")
print(f"Generation Stage : {np.average(generation_time_list) * 1000:.2f} ms/token")
# print(f"Average Speed : {average_speed * 1000:.2f} ms/token")
print("=" * 50)
# print("token num:", total_tokens)
# print("Model total Time = ", (context_time + np.sum(generation_time_list))*1000, "ms" )
return " ".join(output_text)
def device_warmup(device:str):
warm_up = torch.randn((4096,4096)).to(device)
torch.mm(warm_up,warm_up)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_type', type=str, default='LLaMa', help='type of the model')
parser.add_argument('--model_path', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b', help='path to the model')
parser.add_argument('--precision' , type=str, default='W4A16', help='compute precision')
parser.add_argument('--device' , type=str, default='cuda')
parser.add_argument('--q_group_size', type=int, default=128)
parser.add_argument('--load_quant', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt', help='path to the pre-quanted 4-bit weights')
args = parser.parse_args()
assert args.model_type.lower() in ["llama", "falcon", "mpt"], "We only support llama & falcon & mpt now"
assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now"
gen_params.n_predict = 512
gen_params.n_vocab = 32000
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.kaiming_normal_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
# config.init_device="meta"
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True)
modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
if args.precision == "W4A16":
if args.model_type.lower() == "llama":
model = load_awq_llama_fast(model, args.load_quant, 4, args.q_group_size, args.device)
else:
model = load_awq_model(model, args.load_quant, 4, args.q_group_size, args.device)
else:
model = AutoModelForCausalLM.from_pretrained(args.model_path, config=config, torch_dtype=torch.float16, trust_remote_code=True).to(args.device)
# device warm up
device_warmup(args.device)
if args.model_type.lower() == 'falcon':
stream_generator = FalconStreamGenerator
else:
stream_generator = StreamGenerator
# Optimize AWQ quantized model
if args.precision == "W4A16" and args.model_type.lower() == 'llama':
from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
make_quant_attn(model, args.device)
make_quant_norm(model)
make_fused_mlp(model)
model_prompter = get_prompter(args.model_type, args.model_path)
stop_token_ids = get_stop_token_ids(args.model_type, args.model_path)
count = 0
while True:
# Get input from the user
input_prompt = input("USER: ")
if input_prompt == "":
print("EXIT...")
break
model_prompter.insert_prompt(input_prompt)
output_stream = stream_generator(model, tokenizer, model_prompter.model_input, gen_params, device=args.device, stop_token_ids = stop_token_ids)
outputs = stream_output(output_stream)
model_prompter.update_template(outputs)
count += 1
from .fused_norm import *
from .fused_attn import *
from .fused_mlp import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment