Commit 3253240a authored by xiabo

Sync with the latest official release 0.1.0; the main addition is paged attention.

Update the test cases.
parent a8ce8d27
@@ -14,16 +14,26 @@ LMDeploy is jointly developed by [MMDeploy](https://github.com/open-mmlab/mmdeploy) and [MMRazor](ht
Persistent batch inference: further optimizes model execution efficiency.
Official LMDeploy GitHub repository: [https://github.com/InternLM/lmdeploy](https://github.com/InternLM/lmdeploy)
## Supported Models
| Model        | Model Parallelism | FP16 |
| :----------: | :---------------: | :--: |
| Llama        | Yes               | Yes  |
| Llama2       | Yes               | Yes  |
| InternLM-7B  | Yes               | Yes  |
| InternLM-20B | Yes               | Yes  |
| QWen-7B      | Yes               | Yes  |
| QWen-14B     | Yes               | Yes  |
| QWen-72B     | Yes               | Yes  |
| Baichuan-7B  | Yes               | Yes  |
| Baichuan2-7B | Yes               | Yes  |
| wizardlM     | Yes               | Yes  |
| vicuna       | Yes               | Yes  |
| puyu         | Yes               | Yes  |
| codellama    | Yes               | Yes  |
| solar        | Yes               | Yes  |
| ultracm      | Yes               | Yes  |
| ultralm      | Yes               | Yes  |
| yi           | Yes               | Yes  |
## Installation
@@ -32,7 +42,7 @@ Official LMDeploy GitHub repository: [https://github.com/InternLM/lmdeploy](https://github
#### Build Environment Setup
Pull the image from the SourceFind registry and start a Docker container:
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:lmdeploy1.0-dtk23.10-torch1.13-py38
# Replace <Image ID> with the ID of the image pulled above
# <Host Path> is the host-side path to mount
```
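The `docker run` step itself is collapsed in this diff. A minimal sketch of starting the container, assuming the usual DCU device nodes `/dev/kfd` and `/dev/dri` (everything outside the visible diff here is an assumption), might look like:
```
# sketch only; adjust devices, shm size and paths to your host
docker run -it --shm-size=16g \
    --device=/dev/kfd --device=/dev/dri \
    -v <Host Path>:/work \
    <Image ID> /bin/bash
```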
@@ -80,25 +90,25 @@ cd dist && pip3 install lmdeploy*
## Model Serving
### Model Conversion
```
# <model_name>: name of the model ('llama', 'internlm', 'vicuna', 'wizardlM', 'internlm-chat-7b', 'internlm-chat', 'internlm-chat-7b-8k', 'internlm-chat-20b', 'internlm-20b', 'baichuan-7b', 'baichuan2-7b', 'puyu', 'llama2', 'qwen-7b', 'qwen-14b', 'qwen-72b', 'codellama', 'solar', 'ultralm', 'ultracm', 'yi')
# <model_path>: path to the model
# <model_format>: format of the model ('llama', 'hf', or None); optional, defaults to None, in which case the code picks the format based on the model
# <tokenizer_path>: path to the tokenizer model (defaults to None, in which case it is looked up inside model_path: 'tokenizer.model' for most models, 'qwen.tiktoken' for QWen)
# <dst_path>: destination path for the converted output (defaults to ./workspace)
# <tp>: number of GPUs used for tensor parallelism; should be a power of two (2^n)
lmdeploy convert --model_name ${model_name} --model_path ${model_path} --model_format ${model_format} --tokenizer_path ${tokenizer_path} --dst_path ${dst_path} --tp ${tp}
```
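A concrete invocation, carried over from the per-model examples this commit removes, converting a HuggingFace-format llama checkpoint on a single card:
```
lmdeploy convert --model_name llama --model_path /path/to/model --model_format hf --tokenizer_path None --dst_path ./workspace_llama --tp 1
```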
### Running
#### Run from the terminal
```
# <model_path>: path to the converted model
lmdeploy chat turbomind --model_path ${model_path} --tp ${tp}  # tp must match the tp used for conversion; after typing a question, press Enter twice to run inference
```
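For example, chatting with the workspace produced by the conversion above:
```
lmdeploy chat turbomind --model_path ./workspace_llama --tp 1
```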
#### Interacting through a web page
```
Run on the bash side:
# <model_path_or_server>: path of the deployed model, or a tritonserver URL, or a restful api URL. The former serves directly through gradio; a URL runs against tritonserver by default. If the URL is a restful api, also enable the 'restful_api' flag.
# <server_name>: IP address of the gradio server
@@ -107,95 +117,42 @@
# <tp>: number of GPUs used for tensor parallelism; should be a power of two (2^n), matching the value used during model conversion
# <restful_api>: flag describing model_path_or_server (defaults to False)
lmdeploy serve gradio --model_path_or_server ${model_path_or_server} --server_name ${ip} --server_port ${port} --batch_size 32 --tp ${tp}
```
Open {ip}:{port} in a web page to start chatting. **Make sure '{ip}:{port}' is reachable from an external browser.**
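A filled-in example, again taken from the removed per-model sections:
```
lmdeploy serve gradio --model_path_or_server ./workspace_llama --server_name {ip} --server_port {port} --batch_size 32 --tp 1
```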
#### Using the api-server
Start the server:
```
Run on the bash side:
# --instance_num: number of turbomind inference instances, i.e. the concurrency the model supports (defaults to 32)
lmdeploy serve api_server ${model_path} --server_name ${server_ip} --server_port ${server_port} --instance_num ${instance_num} --tp ${tp}
```
Copy the http url printed by api_server into a browser to see every available API and how to use it. **Be sure to check http://{server_ip}:{server_port}!**
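Upstream, the api_server is a FastAPI app; assuming that also holds for this port, a quick reachability check is the sketch below (the Swagger page lives at the root URL mentioned above, and endpoint names may differ between versions):
```
# /openapi.json is FastAPI's machine-readable schema
curl http://{server_ip}:{server_port}/openapi.json
```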
##### CLI client
The restful api service can be tested with a command-line client, for example:
```
# api_server_url is what api_server prints, e.g. http://localhost:23333
lmdeploy serve api_client api_server_url
```
##### webui
You can also exercise the restful-api directly through the web UI.
```
# api_server_url is the address produced by api_server, e.g. http://localhost:23333
# server_name and server_port are used to serve the gradio UI
# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006
lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port}
```
Detailed api-server usage is documented [here](docs/zh_cn/restful_api.md).
Deployment of the codellama model is described in [codellama](docs/zh_cn/supported_models/codellama.md).
## Result
![qwen inference](docs/dcu/qwen推理.gif)
### For more details, see the [docs](./docs/zh_cn/serving.md)
## Version Query
- `python -c "import lmdeploy; print(lmdeploy.__version__)"` prints this package's version; the version number is kept in sync with the official release, e.g. 0.1.0
## Known Issue
-
...
@@ -10,10 +10,10 @@ from threading import Thread
from typing import List

import numpy as np
# from pynvml import (NVMLError, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
#                     nvmlDeviceGetMemoryInfo, nvmlDeviceGetName,
#                     nvmlDeviceGetPowerState, nvmlDeviceGetTemperature,
#                     nvmlInit, nvmlShutdown, nvmlSystemGetDriverVersion)
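# NOTE: pynvml binds the NVIDIA management library (NVML), which does not exist
# on DCU machines, so the import is disabled in this port.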
from tqdm import tqdm

from lmdeploy.turbomind import TurboMind
@@ -186,76 +186,76 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
        percentiles, throughput, tm_model.gpu_count


# class MemoryMonitor:
#     from multiprocessing import Manager
#     max_mem = Manager().Value('f', 0)  # GB
#     device_count = Manager().Value('f', 0)

#     @staticmethod
#     def nvidia_info():
#         # pip install nvidia-ml-py
#         nvidia_dict = {
#             'state': True,
#             'nvidia_version': '',
#             'nvidia_count': 0,
#             'gpus': []
#         }
#         try:
#             nvmlInit()
#             nvidia_dict['nvidia_version'] = nvmlSystemGetDriverVersion()
#             nvidia_dict['nvidia_count'] = nvmlDeviceGetCount()
#             for i in range(nvidia_dict['nvidia_count']):
#                 handle = nvmlDeviceGetHandleByIndex(i)
#                 memory_info = nvmlDeviceGetMemoryInfo(handle)
#                 gpu = {
#                     'gpu_name': nvmlDeviceGetName(handle),
#                     'total': memory_info.total,
#                     'free': memory_info.free,
#                     'used': memory_info.used,
#                     'temperature': f'{nvmlDeviceGetTemperature(handle, 0)}℃',
#                     'powerStatus': nvmlDeviceGetPowerState(handle)
#                 }
#                 nvidia_dict['gpus'].append(gpu)
#         except NVMLError as _:  # noqa
#             nvidia_dict['state'] = False
#         except Exception as _:  # noqa
#             nvidia_dict['state'] = False
#         finally:
#             try:
#                 nvmlShutdown()
#             except:  # noqa
#                 pass
#         return nvidia_dict

#     @classmethod
#     def mem_monitor(cls):
#         info = cls.nvidia_info()
#         max_mem = 0
#         mem_start = 0
#         cls.device_count.value = len(info['gpus'])
#         for used_total in info['gpus']:
#             mem_start += used_total['used']
#         while True:
#             info = cls.nvidia_info()
#             used = 0
#             for used_total in info['gpus']:
#                 used += used_total['used']
#             if used > max_mem:
#                 max_mem = used
#                 cls.max_mem.value = (max_mem - mem_start) / (1 << 30)

#     @classmethod
#     def start(cls):
#         cls._running = True
#         from multiprocessing import Process
#         cls.proc = Process(target=cls.mem_monitor)
#         cls.proc.start()

#     @classmethod
#     def terminate(cls) -> float:
#         """Terminate the subprocess and return maximum memory."""
#         cls.proc.kill()
#         return cls.max_mem.value
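# NOTE: with the NVML bindings above disabled, MemoryMonitor cannot run; peak
# memory and device count are stubbed to 0 in main() below.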
@dataclass
@@ -345,7 +345,7 @@ def main():
    for batch in args.concurrency:
        for prompt_tokens, completion_tokens in zip(args.prompt_tokens,
                                                    args.completion_tokens):
            # MemoryMonitor.start()
            from functools import partial
            from multiprocessing import Pool
            profile_target = partial(profile_throughput,
@@ -362,8 +362,10 @@ def main():
            model_name, first_token_latency, percentiles, \
                throughput_per_proc, tp = output[0]
            time.sleep(5)  # wait a while for releasing GPU mem
            # memory = MemoryMonitor.terminate()
            # device_count = MemoryMonitor.device_count.value
            memory = 0
            device_count = 0
            results.append(
                ProfileResult(model_name=model_name,
                              batch=batch,
...
# Copyright (c) OpenMMLab. All rights reserved.

add_executable(llama_triton_example llama_triton_example.cc)
#target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart
target_link_libraries(llama_triton_example PUBLIC -lcublas -lrocblas -lcudart
                      LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils
                      nvtx_utils word_list -lpthread)
...
@@ -10,4 +10,4 @@ cmake .. \
      -DBUILD_MULTI_GPU=ON \
      -DCMAKE_CUDA_FLAGS="-lineinfo" \
      -DUSE_NVTX=OFF \
      -DBUILD_TEST=ON
@@ -73,3 +73,4 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
#set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
#add_subdirectory(gemm_s_f16)
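# decoder_multihead_attention holds the new decode-path attention kernels; judging
# from the commit message, this is where the paged-attention support is added.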
add_subdirectory(decoder_multihead_attention)
@@ -685,15 +685,15 @@ __device__ inline void m16n8k8(const uint32_t * A, const uint32_t * B, /*const f
    __builtin_memcpy(smem+(base+2), B, sizeof(uint32_t));
    __syncthreads();

    /* From D's point of view, each thread computes its share of D: starting
       from thread 0, loop to fetch one row of A and two columns of B.
       s is the thread index into the B matrix,
       baseA is the thread index into A,
       baseB0 selects the first B column for the current thread, baseB1 the second.
    */
    int s = baseId+(tid%4)*8, e = s+4;
    for (int i = s; i < e; ++i) {
        // A[0]->i  A[1]->i+1  B[0]->i+2
        int baseA = (tid-tid%4+i-s)*3;  // thread index of the first column in tid's row, plus stride, *3
        int baseB0 = i*3, baseB1 = (i+4)*3;
        f16mulf16addf32(smem[baseA], smem[baseB0+2], D, D);
@@ -1137,6 +1137,7 @@ inline __device__ int64_t quant(uint4 a, const float scale, const float zp)
    return int64;
}

#ifdef ENABLE_BF16
// bfloat16 to int8
inline __device__ int8_t quant(__nv_bfloat16 a, const float scale, const float zp)
{
@@ -1184,6 +1185,7 @@ inline __device__ int64_t quant(bf16_8_t a, const float scale, const float zp)
    int16[3] = quant(a.w, scale, zp);
    return int64;
}
#endif

// int8 to float32, then `vec_conversion` to target format
inline __device__ float dequant(int8_t a, const float scale, const float zp)
...
@@ -326,7 +326,7 @@ inline __device__ float2 half2_to_float2(uint32_t v)
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
inline __device__ float bfloat16_to_float(__nv_bfloat16 h)
{
    return __bfloat162float(h);
@@ -344,7 +344,7 @@ inline __device__ float2 bfloat162_to_float2(__nv_bfloat162 v)
    // asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
    // return make_float2(bfloat16_to_float(lo), bfloat16_to_float(hi));
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float add(float a, uint16_t b)
...
@@ -5,12 +5,12 @@ add_library(decoder_multihead_attention STATIC decoder_multihead_attention.cu kv
#                        --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep)
set_property(TARGET decoder_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET decoder_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
#target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutlass)

#add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu)
# target_compile_options(test_decoder_multihead_attention PRIVATE
#                        --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
#target_link_libraries(test_decoder_multihead_attention PRIVATE
#                      decoder_multihead_attention
#                      decoder_masked_multihead_attention
#                      cublas)
@@ -181,7 +181,8 @@ inline __device__ void Store(T* dst, const Array<T, N>& src)
        *(uint1*)dst = (const uint1&)src;
    }
    else {
        printf("=====array_ops.h 184\n");
        // static_assert(!std::is_same_v<T, T>);
    }
}
@@ -200,7 +201,8 @@ inline __device__ void Ldg(Array<T, N>& dst, const T* src)
        (uint&)dst = __ldg((const uint*)src);
    }
    else {
        printf("=====array_ops.h 204\n");
        // static_assert(!std::is_same_v<T, T>);
    }
}
@@ -219,7 +221,8 @@ inline __device__ void Lds(Array<T, N>& dst, const T* src)
        (uint1&)dst = *(const uint1*)src;
    }
    else {
        printf("=====array_ops.h 224\n");
        // static_assert(!std::is_same_v<T, T>);
    }
}
@@ -377,7 +380,15 @@ struct ConvertKvCache<Ti, int8_t> {
    inline __device__ uint8_t round(float x) const
    {
        uint32_t y;
        printf("======array_ops.h 380\n");
        // asm("cvt.rni.sat.u8.f32 %0, %1;\n" : "=r"(y) : "f"(x));
        // Portable replacement for the saturating u8 conversion above. Note that
        // cvt.rni rounds ties to even, while std::round rounds halves away from
        // zero, so the two can differ on exact .5 inputs.
        if (x >= 255) {
            y = 255;
        } else if (x < 0) {
            y = 0;
        } else {
            y = std::round(x);
        }
        return y;
    }
@@ -414,11 +425,11 @@ inline __device__ Array<float, 4> fast_i2f_f32_s8(const Array<int8_t, 4>& x)
    static constexpr uint32_t m1 = 0x7614;
    static constexpr uint32_t m2 = 0x7624;
    static constexpr uint32_t m3 = 0x7634;

    printf("======array_ops.h 417\n");
    // asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[0]) : "r"(i8s), "n"(f32_magic), "n"(m0));
    // asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[1]) : "r"(i8s), "n"(f32_magic), "n"(m1));
    // asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[2]) : "r"(i8s), "n"(f32_magic), "n"(m2));
    // asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[3]) : "r"(i8s), "n"(f32_magic), "n"(m3));

    if (0) {  // fused with dequantization
        PRAGMA_UNROLL
...
@@ -25,7 +25,8 @@ struct DecoderMultiHeadAttentionParams {
    const float* __restrict__ rope_theta;

    // kv cache
    // size_t layer_offset;
    int layer_offset;

    /// cache layout M,[N,H,x,D]
    /// S: [s0/x, s1/x, s2/x, ..., sn-1/x], si <- block
...
@@ -9,7 +9,7 @@
#include <climits>
#include <cmath>
#include <cstdint>
// #include <cuda_pipeline_primitives.h>
#include <type_traits>

#include "decoder_multihead_attention_params.h"
@@ -92,8 +92,10 @@ struct DecoderMultiHeadAttentionKernel {
    Tkv* __restrict__ k_cache_;  // [S, D]
    Tkv* __restrict__ v_cache_;  // [S, D]

    // const void** __restrict__ k_cache_ptrs_;
    // const void** __restrict__ v_cache_ptrs_;
    void** __restrict__ k_cache_ptrs_;
    void** __restrict__ v_cache_ptrs_;

    Tkv*   __restrict__ smem_Kv_;
    float* __restrict__ smem_S_;
@@ -325,18 +327,18 @@ struct DecoderMultiHeadAttentionKernel {
    __device__ void CpAsyncWait()
    {
        // __pipeline_wait_prior(kStages - 2);
    }

    __device__ void CpAsyncCommit()
    {
        // __pipeline_commit();
    }

    __device__ void CpAsyncFlush()
    {
        // __pipeline_commit();
        // __pipeline_wait_prior(0);
    }

    static constexpr int kKvVecPerThread = MapKv::kIterC;
...
@@ -14,12 +14,15 @@ namespace turbomind {
#endif

struct BlockIterator {
    // const void** ptrs_;
    // const void*  prefetch_;
    void** ptrs_;
    void*  prefetch_;

    BlockIterator() = default;

    __device__ BlockIterator(/*const */ void** block_ptrs): ptrs_{block_ptrs}
    {
        // prefetch first ptr
        prefetch_ = *ptrs_++;
@@ -111,7 +114,8 @@ struct Iterator {
        is_valid_s_ = offset_s_ < seq_len;
    }

    // __device__ Iterator(const void** block_ptrs,
    __device__ Iterator(void** block_ptrs,
                        int block_size,
                        int layer_offset,
                        int head_idx,
@@ -258,25 +262,26 @@ struct Iterator {
    }
#endif

    // static __device__ void CpAsync(T* __restrict__ dst, const T* __restrict__ src, bool mask)
    // {
    //     const int     smem_int_ptr = cast_smem_ptr_to_uint(dst);
    //     constexpr int cp_size      = sizeof(AccessType);
    //     printf("======iterator.h 265\n");
    // #if TURBOMIND_ARCH_SM80
    //     // clang-format off
    //     asm volatile("{\n"
    //                  "  .reg .pred p;\n"
    //                  "  setp.ne.b32 p, %0, 0;\n"
    //                  "  @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
    //                  "}\n" ::"r"((int)mask),
    //                  "r"(smem_int_ptr),
    //                  "l"(src),
    //                  "n"(cp_size));
    //     // clang-format on
    // #else
    //     assert(TURBOMIND_ARCH_SM80);
    // #endif
    // }

    static __device__ void Copy(T* __restrict__ dst, const T* __restrict__ src, bool mask)
    {
@@ -287,12 +292,12 @@ struct Iterator {
    __device__ void Prefetch(bool mask)
    {
        // if constexpr (TURBOMIND_ARCH_SM80) {
        //     CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
        // }
        // else {
        Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
        // }
    }

    __device__ void Load(AccessType (&frag)[ThreadMap::kIterC])
...
@@ -14,19 +14,20 @@ namespace turbomind {
#define TURBOMIND_S4_DEQUANT_USE_FMA 0
#endif

// #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
// #define TURBOMIND_ARCH_SM75 1
// #else
// #define TURBOMIND_ARCH_SM75 0
// #endif

// #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
// #define TURBOMIND_ARCH_SM80 1
// #else
// #define TURBOMIND_ARCH_SM80 0
// #endif

// constexpr int WARP_SIZE = 32;
constexpr int WARP_SIZE = 64;
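// NOTE: 64 matches the DCU/ROCm wavefront width; CUDA warps are 32 lanes wide.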
#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
#if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
@@ -68,22 +69,22 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
    // dependency if we issue immediately before required.
    const uint32_t top_i4s = i4s >> 8;
    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
    // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
    //     : "=r"(h[0])
    //     : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
    // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
    //     : "=r"(h[1])
    //     : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
    // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
    //     : "=r"(h[2])
    //     : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
    // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
    //     : "=r"(h[3])
    //     : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    printf("=========common.h 86\n");

    // I use inline PTX below because I am not sure if the compiler will emit
    // float2half instructions if I use the half2 ctor. In this case, I chose
    // performance reliability over code readability.
@@ -101,13 +102,13 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
    // Finally, we construct the output numbers.
    // Convert elt_01
    // asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
    // // Convert elt_23
    // asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
    // // Convert elt_45
    // asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
    // // Convert elt_67
    // asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

    return result;
}
@@ -130,27 +131,27 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
    // dependency if we issue immediately before required.
    const uint32_t top_i4s = i4s >> 8;
    printf("=========common.h 133\n");

    // if (0) {  // 1024 & 64
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0));
    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1));
    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0));
    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1));
    // }
    // else {  // 64 only, trade 4 hfma2 with 2 shifts
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
    //     h[0] <<= 4;
    //     h[2] <<= 4;
    //     // we don't need to subtract the magic nums because zeros will go through the same dequant function
    //     // and carry the same magic constant, the magic num will be canceled out after subtracting zeros
    // }

    return result;
}
@@ -158,62 +159,64 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{
    uint32_t smem_int_ptr;
    printf("=========common.h 161\n");
    // asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
    //     : "=r"(smem_int_ptr)
    //     : "l"(ptr));
    // NOTE: with the PTX disabled, smem_int_ptr is returned uninitialized; the
    // cp.async/ldmatrix paths that consume it are also disabled in this port.
    return smem_int_ptr;
}

__inline__ __device__ void ldmatrix_m8n8_x4_b16(uint& d0, uint& d1, uint& d2, uint& d3, uint32_t smem_int_ptr)
{
    printf("=========common.h 171\n");
    // #if TURBOMIND_ARCH_SM75
    //     asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
    //         : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
    //         : "r"(smem_int_ptr));
    // #else
    //     assert(TURBOMIND_ARCH_SM75);
    // #endif
}

__inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t smem_int_ptr)
{
    printf("=========common.h 183\n");
    // #if TURBOMIND_ARCH_SM75
    //     asm("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(d0), "=r"(d1) : "r"(smem_int_ptr));
    // #else
    //     assert(TURBOMIND_ARCH_SM75);
    // #endif
}

// __inline__ __device__ void wait_flag(int* lock, int status, int thread_id)
// {
//     int state = 0;
//     while (__syncthreads_and(state != status)) {
//         if (thread_id == 0) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
//             asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
// #else
//             asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
// #endif
//         }
//     }
//     __syncthreads();  // memory fence
// }

// __inline__ __device__ void release_flag(int* lock, int status, int thread_id)
// {
//     __syncthreads();  // memory fence
//     if (thread_id == 0) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
//         asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
// #else
//         asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
// #endif
//     }
// }

__inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
{
@@ -223,14 +226,14 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
    auto& t = (const uint&)x;

    uint u, v;
    // if (TURBOMIND_S4_DEQUANT_USE_FMA) {
    //     asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
    // }
    // else {
    //     asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
    //     asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
    // }
    printf("=========common.h 235\n");

    return (half2&)v;
}
...
@@ -8,73 +8,73 @@
namespace turbomind {

// #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
// #define L2_CACHEHINT(size) ".L2::" #size "B"
// #else
// #define L2_CACHEHINT(size)
// #endif

// template<typename T>
// __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
// {
// #if TURBOMIND_ARCH_SM80
//     constexpr int cp_size = sizeof(T);
//     static_assert(cp_size == 16, "cp.async.cg requires cp_size == 16");
//     // clang-format off
//     asm volatile("{\n"
//                  "  .reg .pred p;\n"
//                  "  setp.ne.b32 p, %0, 0;\n"
//                  "  @p cp.async.cg.shared.global" L2_CACHEHINT(256) " [%1], [%2], %3;\n"
//                  "}\n" ::"r"((int)mask),
//                  "r"(smem_int_ptr),
//                  "l"(src),
//                  "n"(cp_size));
//     // clang-format on
// #else
//     assert(TURBOMIND_ARCH_SM80);
// #endif
// }

// template<typename T>
// __inline__ __device__ void cp_async_cg_B(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
// {
// #if TURBOMIND_ARCH_SM80
//     constexpr int cp_size = sizeof(T);
//     static_assert(cp_size == 16, "cp.async.cg requires cp_size == 16");
//     // clang-format off
//     asm volatile("{\n"
//                  "  .reg .pred p;\n"
//                  "  setp.ne.b32 p, %0, 0;\n"
//                  "  @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
//                  "}\n" ::"r"((int)mask),
//                  "r"(smem_int_ptr),
//                  "l"(src),
//                  "n"(cp_size));
//     // clang-format on
// #else
//     assert(TURBOMIND_ARCH_SM80);
// #endif
// }

// template<typename T>
// __inline__ __device__ void cp_async_ca(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
// {
// #if TURBOMIND_ARCH_SM80
//     constexpr int cp_size = sizeof(T);
//     // clang-format off
//     asm volatile("{\n"
//                  "  .reg .pred p;\n"
//                  "  setp.ne.b32 p, %0, 0;\n"
//                  "  @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
//                  "}\n" ::"r"((int)mask),
//                  "r"(smem_int_ptr),
//                  "l"(src),
//                  "n"(cp_size));
//     // clang-format on
// #else
//     assert(TURBOMIND_ARCH_SM80);
// #endif
// }

template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES>
struct IteratorA {
@@ -237,13 +237,13 @@ struct IteratorA {
    __device__ void prefetch(bool mask)
    {
        // #if TURBOMIND_ARCH_SM80
        //     cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
        // #else
        if (mask) {
            *(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
        }
        // #endif
    }
};
@@ -424,13 +424,13 @@ struct IteratorQ {
    __device__ void prefetch(bool mask)
    {
        // #if TURBOMIND_ARCH_SM80
        //     cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
        // #else
        if (mask) {
            *(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
        }
        // #endif
    }
};
@@ -626,14 +626,14 @@ struct IteratorB {
    __device__ void prefetch(bool mask)
    {
        // #if TURBOMIND_ARCH_SM80
        //     cp_async_cg_B(
        //         smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask);
        // #else
        if (is_valid_n_ && mask) {
            *(AccessType*)((uint8_t*)smem_ + tmp_dst_offset_) = __ldg((const AccessType*)(src_ + tmp_src_offset_));
        }
        // #endif
    }
};
...
@@ -9,41 +9,41 @@
namespace turbomind {

// __inline__ __device__ void
// mma_m16n8k8_row_col(Array<float, 4>& d, const Array<half, 4>& a, const Array<half, 2>& b, Array<float, 4>& c)
// {
// #if TURBOMIND_ARCH_SM75
//     uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
//     uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
//     float const*    C = reinterpret_cast<float const*>(&c);
//     float*          D = reinterpret_cast<float*>(&d);
//     asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
//         "{%4,%5}, {%6}, {%7,%8,%9,%10};\n"
//         : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
//         : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
// #else
//     assert(TURBOMIND_ARCH_SM75);
// #endif
// }

__inline__ __device__ void
mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
{
    // NOTE: both the SM80 PTX path and the SM75 fallback are commented out,
    // so this MMA is currently a no-op in this port.
    // #if TURBOMIND_ARCH_SM80
    //     uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
    //     uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
    //     float const*    C = reinterpret_cast<float const*>(&c);
    //     float*          D = reinterpret_cast<float*>(&d);
    //     asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
    //         "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
    //         : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
    //         : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
    // #else
    //     const Array<half, 4>* _a = (const Array<half, 4>*)&a;
    //     const Array<half, 2>* _b = (const Array<half, 2>*)&b;
    //     mma_m16n8k8_row_col(d, _a[0], _b[0], c);
    //     mma_m16n8k8_row_col(d, _a[1], _b[1], d);
    // #endif
}

__inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_id)
@@ -64,29 +64,29 @@ __inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_
    return (uint&)r;
}

// #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
// __inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
// {
// #if TURBOMIND_ARCH_SM75
//     uint d;
//     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(d) : "r"(a));
//     return d;
// #else
//     assert(TURBOMIND_ARCH_SM75);
//     return 0;
// #endif
// }
// #endif

__inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
{
    // NOTE: both branches are commented out here, so this function currently
    // falls off the end without returning a value.
    // #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
    //     (void)lane_id;
    //     return transpose_m8n8_b16_movmatrix(a);
    // #else
    //     return transpose_m8n8_b16_warp_shuffle(a, lane_id);
    // #endif
}

namespace ops {
@@ -242,7 +242,7 @@ struct Gemm {
        constexpr int      SLICE_GROUP = (SLICES + 7) / 8;
        constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE;
        const uint32_t     barrier_id  = slice_id / SLICE_GROUP + 1;
        // asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
    }
}
...
@@ -50,7 +50,7 @@ if (NOT MSVC)
endif()

add_executable(llama_gemm llama_gemm.cc)
target_link_libraries(llama_gemm PUBLIC -lrocblas cudart gpt_gemm_func memory_utils cuda_utils logger)
install(TARGETS llama_gemm DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/bin)

find_package(Catch2 3 QUIET)
...
@@ -5,7 +5,7 @@
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
#include <cooperative_groups.h>
// #include <cooperative_groups/reduce.h>
#include <cuda_fp16.h>

namespace cg = cooperative_groups;
...
@@ -962,8 +962,8 @@ void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cud
// template class FlashAttentionOp<float>;
// template class FlashAttentionOp<half>;
// #ifdef ENABLE_BF16
// template class FlashAttentionOp<__nv_bfloat16>;
// #endif

}  // namespace turbomind
@@ -64,7 +64,8 @@ void UnifiedAttentionLayer<T>::allocateBuffer(size_t num_token,
    k_buf_2_ = q_buf_2_ + local_head_num_ * bsz * max_q * size_per_head_;
    v_buf_2_ = k_buf_2_ + local_kv_head_num_ * bsz * max_q * size_per_head_;

    // if (use_fmha_) {
    if (0) {
        FlashAttentionOp<T> flash_attention(bsz, local_head_num_, max_k, max_q, size_per_head_);
        if (flash_attention.get_workspace_size() > 0) {
            qk_buf_float_ =
@@ -106,7 +107,7 @@ void UnifiedAttentionLayer<T>::freeBuffer()
        allocator_->free((void**)(&q_buf_2_));
        allocator_->free((void**)(&qkv_buf_3_));

        // allocator_->free((void**)&qk_buf_float_);
        allocator_->free((void**)(&k_cache_buf_));
        allocator_->free((void**)(&qk_buf_));
        allocator_->free((void**)(&qkv_buf_2_));
@@ -346,7 +347,8 @@ void UnifiedAttentionLayer<T>::prefill(T* output,
                                stream_);
    sync_check_cuda_error();

    // if (use_fmha_) {
    if (0) {
        fusedMultiHeadAttention(output,
                                q_buf_2_,
                                tmp_k_ptrs,
@@ -456,66 +458,66 @@ void UnifiedAttentionLayer<T>::decode(T* output,
    }
}

// template<typename T>
// void UnifiedAttentionLayer<T>::fusedMultiHeadAttention(T*       output,
//                                                        const T* query,
//                                                        T**      key_cache_ptrs,
//                                                        T**      val_cache_ptrs,
//                                                        size_t   cache_layer_offset,
//                                                        T*       attention_mask,
//                                                        int*     cu_seqlens,
//                                                        int*     context_lengths,
//                                                        int      batch_size,
//                                                        int      max_q_len,
//                                                        int      max_k_len,
//                                                        int      max_seq_len)
// {
//     //////////////////////////////////////////////
//     // flash attention
//     // flash attention 2 only support half inputs
//     using AttentionOp = FlashAttentionOp<T>;
//     using Layout      = typename AttentionOp::AttentionLayout;
//     Layout layout_q{
//         int(local_head_num_ * max_q_len * size_per_head_), int(size_per_head_), int(max_q_len * size_per_head_)};
//     Layout layout_k{int(local_head_num_ * max_seq_len * size_per_head_),
//                     int(size_per_head_),
//                     int(max_seq_len * size_per_head_),
//                     false,
//                     cache_layer_offset,
//                     key_cache_ptrs};
//     Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_),
//                     int(size_per_head_),
//                     int(max_seq_len * size_per_head_),
//                     false,
//                     cache_layer_offset,
//                     val_cache_ptrs};
//     Layout layout_o{
//         int(local_head_num_ * max_q_len * size_per_head_),
//         int(local_head_num_ * size_per_head_),
//         int(size_per_head_),
//         true,
//     };
//     size_t      group_size = size_t(local_head_num_ / local_kv_head_num_);
//     AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
//     typename AttentionOp::Params attn_params{output,
//                                              (T*)query,
//                                              k_cache_buf_,
//                                              v_cache_buf_,
//                                              attention_mask,
//                                              qk_buf_float_,
//                                              cu_seqlens,
//                                              nullptr,
//                                              nullptr,
//                                              context_lengths,
//                                              group_size,
//                                              layout_q,
//                                              layout_k,
//                                              layout_v,
//                                              layout_o};
//     //
//     flash_attention(attn_params, stream_);
// }

template<typename T>
void UnifiedAttentionLayer<T>::unfusedMultiHeadAttention(T* output,
...
@@ -47,6 +47,7 @@ std::shared_ptr<AbstractTransformerModel> AbstractTransformerModel::createLlamaM
            reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0),
            model_dir);
    }
#ifdef ENABLE_BF16
    else if (data_type == "bf16") {
#ifdef ENABLE_BF16
        return std::make_shared<LlamaTritonModel<__nv_bfloat16>>(
@@ -59,6 +60,7 @@ std::shared_ptr<AbstractTransformerModel> AbstractTransformerModel::createLlamaM
        ft::FT_CHECK(false);
#endif
    }
#endif
    else {
        return std::make_shared<LlamaTritonModel<float>>(
            reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"),
@@ -177,7 +179,8 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
    norm_eps_ = reader.GetFloat("llama", "norm_eps");
    start_id_ = reader.GetInteger("llama", "start_id");
    end_id_   = reader.GetInteger("llama", "end_id");
    // use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1);
    use_context_fmha_ = 0;
    cache_block_seq_len_ = reader.GetInteger("llama", "cache_block_seq_len", 0);
    attn_bias_ = reader.GetInteger("llama", "attn_bias", 0);
...