OpenDAS / Lmdeploy · Commits

Commit 3253240a, authored Jan 12, 2024 by xiabo
Corresponds to the latest official release 0.1.0; the main addition is page attention.
Modified the test cases.
Parent: a8ce8d27
Changes: 23 · Showing 20 changed files with 533 additions and 543 deletions (+533 -543)
README.md  +55 -98
benchmark/profile_generation.py  +79 -77
examples/cpp/llama/CMakeLists.txt  +2 -1
generate.sh  +1 -1
src/turbomind/kernels/CMakeLists.txt  +1 -0
src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh  +7 -5
src/turbomind/kernels/decoder_masked_multihead_attention_utils.h  +2 -2
src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt  +6 -6
src/turbomind/kernels/decoder_multihead_attention/array_ops.h  +20 -9
src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h  +2 -1
src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h  +9 -7
src/turbomind/kernels/decoder_multihead_attention/iterator.h  +34 -29
src/turbomind/kernels/gemm_s_f16/common.h  +110 -107
src/turbomind/kernels/gemm_s_f16/cta_iterator.h  +80 -80
src/turbomind/kernels/gemm_s_f16/gemm_template.h  +51 -51
src/turbomind/models/llama/CMakeLists.txt  +1 -1
src/turbomind/models/llama/llama_decoder_kernels.cu  +1 -1
src/turbomind/models/llama/llama_kernels.cu  +3 -3
src/turbomind/models/llama/unified_attention_layer.cc  +65 -63
src/turbomind/triton_backend/llama/LlamaTritonModel.cc  +4 -1
README.md  (+55 -98)

@@ -14,16 +14,26 @@
persistent batch inference: further optimizes model execution efficiency.

Official LMDeploy GitHub repository: [https://github.com/InternLM/lmdeploy](https://github.com/InternLM/lmdeploy)

## Supported models

Removed (old table, with the KV INT8 column):

| Model | Model parallelism | FP16 | KV INT8 |
| :----------: | :------: | :--: | :-----: |
| Llama | Yes | Yes | Yes |
| Llama2 | Yes | Yes | Yes |
| InternLM-7B | Yes | Yes | Yes |
| InternLM-20B | Yes | Yes | Yes |
| QWen-7B | Yes | Yes | Yes |
| QWen-14B | Yes | Yes | Yes |
| Baichuan-7B | Yes | Yes | Yes |
| Baichuan2-7B | Yes | Yes | No |

Added (new table, KV INT8 dropped and more models listed):

| Model | Model parallelism | FP16 |
| :----------: | :------: | :--: |
| Llama | Yes | Yes |
| Llama2 | Yes | Yes |
| InternLM-7B | Yes | Yes |
| InternLM-20B | Yes | Yes |
| QWen-7B | Yes | Yes |
| QWen-14B | Yes | Yes |
| QWen-72B | Yes | Yes |
| Baichuan-7B | Yes | Yes |
| Baichuan2-7B | Yes | Yes |
| wizardlM | Yes | Yes |
| vicuna | Yes | Yes |
| puyu | Yes | Yes |
| codellama | Yes | Yes |
| solar | Yes | Yes |
| ultracm | Yes | Yes |
| ultralm | Yes | Yes |
| yi | Yes | Yes |

## Installation
@@ -32,7 +42,7 @@

#### Preparing the build environment

Pull the image from SourceFind and start a docker container (the image tag is updated):
- docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:lmdeploy-dtk2310-torch1.13-py38
+ docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:lmdeploy1.0-dtk23.10-torch1.13-py38
# <Image ID>: replace with the ID of the docker image pulled above
# <Host Path>: a path on the host machine
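The command that actually starts the container sits outside this hunk. A minimal sketch, filling the two placeholders above with hypothetical values; the device and mount flags are assumptions about a typical DCU container setup, not taken from the README:

```
# <Image ID> and <Host Path> are the placeholders described above; all flags here are illustrative
docker run -it --name lmdeploy-dev \
    --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined \
    -v <Host Path>:/workspace \
    <Image ID> /bin/bash
```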
@@ -80,25 +90,25 @@ cd dist && pip3 install lmdeploy*

## Model serving

Removed: the fixed "Deploy the [LLaMA](https://huggingface.co/huggyllama) service" walk-through ("download the llama model from [here](https://huggingface.co/huggyllama) and deploy it with the following commands, taking 7B as an example"), whose conversion step was `lmdeploy convert --model_name llama --model_path /path/to/model --model_format hf --tokenizer_path None --dst_path ./workspace_llama --tp 1`. It is replaced by the generic sections below; the old placeholder comments listed fewer model names, offered a 'qwen' model_format, and mentioned only 'qwen.tiktoken' as the tokenizer file.

### Model conversion
```
# <model_name> name of the model ('llama', 'internlm', 'vicuna', 'wizardlM', 'internlm-chat-7b', 'internlm-chat', 'internlm-chat-7b-8k', 'internlm-chat-20b', 'internlm-20b', 'baichuan-7b', 'baichuan2-7b', 'puyu', 'llama2', 'qwen-7b', 'qwen-14b', 'qwen-72b', 'codellama', 'solar', 'ultralm', 'ultracm', 'yi')
# <model_path> path to the model
# <model_format> format of the model ('llama', 'hf', or None; may be omitted, it defaults to None and the code picks the format that matches the model)
# <tokenizer_path> path to the tokenizer model (default None; it is looked up inside model_path: 'tokenizer.model' for most models, 'qwen.tiktoken' for QWen)
# <dst_path> destination path for the converted output (default ./workspace)
# <tp> number of GPUs used for tensor parallelism; should be 2^n
lmdeploy convert --model_name ${model_name} --model_path ${model_path} --model_format ${model_format} --tokenizer_path ${tokenizer_path} --dst_path ${dst_path} --tp ${tp}
```
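As a concrete reading of the template above, a hypothetical conversion of a Llama-2 7B HF checkpoint (the paths and tp value are made up for illustration; the removed per-model sections used essentially this invocation):

```
# hypothetical local path; tp=1 keeps the model on a single card
lmdeploy convert --model_name llama2 \
    --model_path /data/models/Llama-2-7b-chat-hf \
    --model_format hf \
    --dst_path ./workspace_llama2 \
    --tp 1
```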
### Run
#### Run from the bash terminal
```
# <model_path>: path of the converted model
lmdeploy chat turbomind --model_path ${model_path} --tp ${tp}   # tp must match the value used for conversion; press Enter twice after typing a question to start inference
```
(This replaces the old numbered "2. Run" steps, whose fixed command was `lmdeploy chat turbomind --model_path ./workspace_llama --tp 1`.)

#### Interact through a web page
```
Run on the bash side:
# <model_path_or_server> path of the deployed model, a tritonserver URL, or a restful api URL. The former serves the model directly with gradio; the latter runs through tritonserver by default. If the input URL is a restful api, also enable the 'restful_api' flag.
# <server_name> IP address of the gradio server
# <tp> number of GPUs used for tensor parallelism, should be 2^n (keep it consistent with the value used for conversion)
# <restful_api> flag describing model_path_or_server (default False)
lmdeploy serve gradio --model_path_or_server ${model_path_or_server} --server_name ${ip} --server_port ${port} --batch_size 32 --tp ${tp}
Open {ip}:{port} in a browser to chat
```
**Make sure '{ip}:{port}' is reachable from an external browser.**
(The old fixed command was `lmdeploy serve gradio --model_path_or_server ./workspace_llama --server_name {ip} --server_port {port} --batch_size 32 --tp 1 --restful_api False`.)
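Filled in with example values (the IP, port and workspace path are hypothetical), the gradio command above becomes:

```
# serve the converted model from ./workspace_llama2 on port 6006
lmdeploy serve gradio --model_path_or_server ./workspace_llama2 \
    --server_name 0.0.0.0 --server_port 6006 --batch_size 32 --tp 1
# then open http://<server ip>:6006 in a browser
```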
@@ -107,95 +117,42 @@

Removed: the per-model deployment walk-throughs for [llama2](https://huggingface.co/meta-llama), [internlm](https://huggingface.co/internlm) and [baichuan](https://huggingface.co/baichuan-inc). Each repeated the same download / convert / chat / gradio steps with a 7B example, differing only in --model_name (llama2; internlm-chat or internlm, chosen according to the model type; baichuan-7b) and --dst_path (./workspace_llama2, ./workspace_intern, ./workspace_baichuan).

Added:

#### Using the api-server
```
Start the server:
# --instance_num: number of turbomind inference instances, i.e. the concurrency the model supports (default 32)
lmdeploy serve api_server ${model_path} --server_name ${server_ip} --server_port ${server_port} --instance_num ${instance_num} --tp ${tp}
```
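A sketch of starting the api-server with concrete values, following the template above; the workspace path, port and instance count are examples, not taken from the README:

```
# assumes a converted model in ./workspace_llama2
lmdeploy serve api_server ./workspace_llama2 \
    --server_name 0.0.0.0 --server_port 23333 \
    --instance_num 32 --tp 1
# the page at http://{server_ip}:{server_port} lists all available APIs
```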
Copy the http url printed by the command above into a browser to see all the APIs and how to use them. Be sure to look at http://{server_ip}:{server_port}!!! Be sure to look at http://{server_ip}:{server_port}!!! Be sure to look at http://{server_ip}:{server_port}!!! Important things are worth saying three times.

CLI client

The restful api service can be tested from the command-line client, for example:
```
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
lmdeploy serve api_client api_server_url
```

webui

You can also test the restful-api directly through the webui:
```
# api_server_url is the one produced by api_server, e.g. http://localhost:23333
# server_name and server_port are used to expose the gradio ui
# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006
lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port}
```

Removed: the matching walk-throughs for [baichuan2](https://huggingface.co/baichuan-inc) and [qwen](https://huggingface.co/Qwen), which used --model_name baichuan2-7b and qwen-7b (the latter with --model_format qwen) and --dst_path ./workspace_baichuan2 and ./workspace_qwen.
For detailed usage of the api-server, refer to the documentation [here](docs/zh_cn/restful_api.md).
Deployment of the codellama model is described in [codellama](docs/zh_cn/supported_models/codellama.md).

## result


### For details, see [docs](./docs/zh_cn/serving.md)

## Version query
- `python -c "import lmdeploy; lmdeploy.__version__"`: the version number is kept in sync with the official release and reports this package's version, now e.g. 0.1.0 (the previous example was 0.0.6).

## Known Issue
- None
benchmark/profile_generation.py  (+79 -77)

@@ -10,10 +10,10 @@ from threading import Thread
 from typing import List
 import numpy as np
-from pynvml import (NVMLError, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
-                    nvmlDeviceGetMemoryInfo, nvmlDeviceGetName,
-                    nvmlDeviceGetPowerState, nvmlDeviceGetTemperature,
-                    nvmlInit, nvmlShutdown, nvmlSystemGetDriverVersion)
+# from pynvml import (NVMLError, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
+#                     nvmlDeviceGetMemoryInfo, nvmlDeviceGetName,
+#                     nvmlDeviceGetPowerState, nvmlDeviceGetTemperature,
+#                     nvmlInit, nvmlShutdown, nvmlSystemGetDriverVersion)
 from tqdm import tqdm
 from lmdeploy.turbomind import TurboMind

@@ -186,76 +186,76 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
 percentiles, throughput, tm_model.gpu_count
The entire NVML-based `MemoryMonitor` helper class is kept but commented out line by line with '#': its Manager-backed max_mem / device_count values, the nvidia_info() static method (nvmlInit, driver version and device count queries, per-GPU name / total / free / used memory, temperature and power state, with NVMLError handling and nvmlShutdown in the finally block), the mem_monitor() polling loop that tracks peak used memory in GB, and the start() / terminate() wrappers around a multiprocessing.Process.
 @dataclass

@@ -345,7 +345,7 @@ def main():
 for batch in args.concurrency:
     for prompt_tokens, completion_tokens in zip(args.prompt_tokens,
                                                 args.completion_tokens):
-        MemoryMonitor.start()
+        # MemoryMonitor.start()
         from functools import partial
         from multiprocessing import Pool
         profile_target = partial(profile_throughput,

@@ -362,8 +362,10 @@ def main():
 model_name, first_token_latency, percentiles, \
     throughput_per_proc, tp = output[0]
 time.sleep(5)  # wait a while for releasing GPU mem
-memory = MemoryMonitor.terminate()
-device_count = MemoryMonitor.device_count.value
+# memory = MemoryMonitor.terminate()
+# device_count = MemoryMonitor.device_count.value
+memory = 0
+device_count = 0
 results.append(
     ProfileResult(model_name=model_name,
                   batch=batch,
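For reference, a hedged sketch of how this benchmark script is typically launched. The argument names are inferred from the hunks above (args.concurrency, args.prompt_tokens, args.completion_tokens); the exact CLI of this fork's script has not been verified:

```
# model path and all numbers are examples; check `python benchmark/profile_generation.py -h` for the real options
python benchmark/profile_generation.py ./workspace_llama2 \
    --concurrency 1 8 \
    --prompt_tokens 64 512 \
    --completion_tokens 512 512
```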
examples/cpp/llama/CMakeLists.txt  (+2 -1)

 # Copyright (c) OpenMMLab. All rights reserved.
 add_executable(llama_triton_example llama_triton_example.cc)
-target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart
+#target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart
+target_link_libraries(llama_triton_example PUBLIC -lcublas -lrocblas -lcudart
                       LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils
                       nvtx_utils word_list -lpthread)
generate.sh  (+1 -1)

@@ -10,4 +10,4 @@ cmake .. \
     -DBUILD_MULTI_GPU=ON \
     -DCMAKE_CUDA_FLAGS="-lineinfo" \
     -DUSE_NVTX=OFF \
-#    -DBUILD_TEST=ON
+    -DBUILD_TEST=ON
src/turbomind/kernels/CMakeLists.txt  (+1 -0)

@@ -73,3 +73,4 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
 #set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
 #add_subdirectory(gemm_s_f16)
+add_subdirectory(decoder_multihead_attention)
src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh  (+7 -5)

@@ -685,15 +685,15 @@ __device__ inline void m16n8k8(const uint32_t* A, const uint32_t* B, ...)
 __builtin_memcpy(smem + (base + 2), B, sizeof(uint32_t));
 __syncthreads();
The Chinese block comment describing the fallback MMA is rewritten; one side of the diff carries it with broken (non-UTF-8) encoding. Its content: "From D's point of view, each thread is responsible for computing part of D, looping from thread 0 and fetching one row of A and two columns of B; s is the thread index into the B matrix, baseA the thread index into A, and baseB0/baseB1 are the first and second columns of B fetched by the current thread."
 int s = baseId + (tid % 4) * 8, e = s + 4;
 for (int i = s; i < e; ++i) {
     // A[0]->i A[1]->i+1 B[0]->i+2
     int baseA = (tid - tid % 4 + i - s) * 3;  // index of the first thread in the current tid's row, plus the stride, times 3
     int baseB0 = i * 3, baseB1 = (i + 4) * 3;
     f16mulf16addf32(smem[baseA], smem[baseB0 + 2], D, D);

@@ -1137,6 +1137,7 @@ inline __device__ int64_t quant(uint4 a, const float scale, const float zp)
     return int64;
 }
+#ifdef ENABLE_BF16
 // bfloat16 to int8
 inline __device__ int8_t quant(__nv_bfloat16 a, const float scale, const float zp)
 {

@@ -1184,6 +1185,7 @@ inline __device__ int64_t quant(bf16_8_t a, const float scale, const float zp)
     int16[3] = quant(a.w, scale, zp);
     return int64;
 }
+#endif
 // int8 to float32, then `vec_conversion` to target format
 inline __device__ float dequant(int8_t a, const float scale, const float zp)
src/turbomind/kernels/decoder_masked_multihead_attention_utils.h  (+2 -2)

The bfloat16 helpers are wrapped in ENABLE_BF16 guards:

@@ -326,7 +326,7 @@ inline __device__ float2 half2_to_float2(uint32_t v)
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
+#ifdef ENABLE_BF16
 inline __device__ float bfloat16_to_float(__nv_bfloat16 h)
 {
     return __bfloat162float(h);

@@ -344,7 +344,7 @@ inline __device__ float2 bfloat162_to_float2(__nv_bfloat162 v)
     // asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
     // return make_float2(bfloat16_to_float(lo), bfloat16_to_float(hi));
 }
+#endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 inline __device__ float add(float a, uint16_t b)
src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt  (+6 -6)

@@ -5,12 +5,12 @@ add_library(decoder_multihead_attention STATIC decoder_multihead_attention.cu kv
 #   --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep)
 set_property(TARGET decoder_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET decoder_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
-target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutlass)
-add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu)
+# target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutlass)
+# add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu)
 # target_compile_options(test_decoder_multihead_attention PRIVATE
 #   --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
-target_link_libraries(test_decoder_multihead_attention PRIVATE
-    decoder_multihead_attention
-    decoder_masked_multihead_attention
-    cublas)
+# target_link_libraries(test_decoder_multihead_attention PRIVATE
+#     decoder_multihead_attention
+#     decoder_masked_multihead_attention
+#     cublas)
src/turbomind/kernels/decoder_multihead_attention/array_ops.h  (+20 -9)

In Store (@@ -181,7 +181,8 @@), Ldg (@@ -200,7 +201,8 @@) and Lds (@@ -219,7 +221,8 @@), the unreachable-type guard is replaced with a debug printf:
     else {
-        static_assert(!std::is_same_v<T, T>);
+        printf("=====array_ops.h 184\n");   // 204 / 224 in the other two functions
+        // static_assert(!std::is_same_v<T, T>);
     }

@@ -377,7 +380,15 @@ struct ConvertKvCache<Ti, int8_t> -- the saturating u8 conversion no longer uses inline PTX:
 inline __device__ uint8_t round(float x) const
 {
     uint32_t y;
-    asm("cvt.rni.sat.u8.f32 %0, %1;\n" : "=r"(y) : "f"(x));
+    printf("======arrat_ops.h 380\n");
+    // asm("cvt.rni.sat.u8.f32 %0, %1;\n" : "=r"(y) : "f"(x));
+    if (x >= 255) {
+        y = 255;
+    }
+    else if (x < 0) {
+        y = 0;
+    }
+    else {
+        y = std::round(x);
+    }
     return y;
 }

@@ -414,11 +425,11 @@ inline __device__ Array<float, 4> fast_i2f_f32_s8(const Array<int8_t, 4>& x) -- the prmt.b32 byte-permute path is disabled:
 static constexpr uint32_t m1 = 0x7614;
 static constexpr uint32_t m2 = 0x7624;
 static constexpr uint32_t m3 = 0x7634;
-asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[0]) : "r"(i8s), "n"(f32_magic), "n"(m0));
-asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[1]) : "r"(i8s), "n"(f32_magic), "n"(m1));
-asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[2]) : "r"(i8s), "n"(f32_magic), "n"(m2));
-asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[3]) : "r"(i8s), "n"(f32_magic), "n"(m3));
+printf("======arrat_ops.h 417\n");
+// asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[0]) : "r"(i8s), "n"(f32_magic), "n"(m0));
+// asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[1]) : "r"(i8s), "n"(f32_magic), "n"(m1));
+// asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[2]) : "r"(i8s), "n"(f32_magic), "n"(m2));
+// asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[3]) : "r"(i8s), "n"(f32_magic), "n"(m3));
 if (0) {  // fused with dequantization
     PRAGMA_UNROLL
src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h  (+2 -1)

@@ -25,7 +25,8 @@ struct DecoderMultiHeadAttentionParams {
 const float* __restrict__ rope_theta;
 // kv cache
-size_t layer_offset;
+// size_t layer_offset;
+int layer_offset;
 /// cache layout M,[N,H,x,D]
 /// S: [s0/x, s1/x, s2/x, ..., sn-1/x], si <- block
src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h  (+9 -7)

@@ -9,7 +9,7 @@
 #include <climits>
 #include <cmath>
 #include <cstdint>
-#include <cuda_pipeline_primitives.h>
+// #include <cuda_pipeline_primitives.h>
 #include <type_traits>
 #include "decoder_multihead_attention_params.h"

@@ -92,8 +92,10 @@ struct DecoderMultiHeadAttentionKernel {
 Tkv* __restrict__ k_cache_;  // [S, D]
 Tkv* __restrict__ v_cache_;  // [S, D]
-const void** __restrict__ k_cache_ptrs_;
-const void** __restrict__ v_cache_ptrs_;
+// const void** __restrict__ k_cache_ptrs_;
+// const void** __restrict__ v_cache_ptrs_;
+void** __restrict__ k_cache_ptrs_;
+void** __restrict__ v_cache_ptrs_;
 Tkv* __restrict__ smem_Kv_;
 float* __restrict__ smem_S_;

@@ -325,18 +327,18 @@ struct DecoderMultiHeadAttentionKernel {
 __device__ void CpAsyncWait()
 {
-    __pipeline_wait_prior(kStages - 2);
+    // __pipeline_wait_prior(kStages - 2);
 }
 __device__ void CpAsyncCommit()
 {
-    __pipeline_commit();
+    // __pipeline_commit();
 }
 __device__ void CpAsyncFlush()
 {
-    __pipeline_commit();
-    __pipeline_wait_prior(0);
+    // __pipeline_commit();
+    // __pipeline_wait_prior(0);
 }
 static constexpr int kKvVecPerThread = MapKv::kIterC;
src/turbomind/kernels/decoder_multihead_attention/iterator.h  (+34 -29)

@@ -14,12 +14,15 @@ namespace turbomind {
 struct BlockIterator {
-    const void** ptrs_;
-    const void*  prefetch_;
+    // const void** ptrs_;
+    // const void* prefetch_;
+    void** ptrs_;
+    void*  prefetch_;
     BlockIterator() = default;
-    __device__ BlockIterator(const void** block_ptrs): ptrs_{block_ptrs}
+    __device__ BlockIterator(/* const */ void** block_ptrs): ptrs_{block_ptrs}
     {
         // prefetch first ptr
         prefetch_ = *ptrs_++;

@@ -111,7 +114,8 @@ struct Iterator {
     is_valid_s_ = offset_s_ < seq_len;
 }
-__device__ Iterator(const void** block_ptrs,
+// __device__ Iterator(const void** block_ptrs,
+__device__ Iterator(void** block_ptrs,
                     int block_size,
                     int layer_offset,
                     int head_idx,

@@ -258,25 +262,26 @@ struct Iterator {
The whole cp.async helper is commented out line by line with //: the CpAsync(T* __restrict__ dst, const T* __restrict__ src, bool mask) definition, its cast_smem_ptr_to_uint / sizeof(AccessType) setup, the "@p cp.async.ca.shared.global" L2_CACHEHINT(128) inline-PTX block, and the #if TURBOMIND_ARCH_SM80 / assert(TURBOMIND_ARCH_SM80) guards; a commented-out debug line `// printf("======iterator.h 265\n");` is added in its place.
 static __device__ void Copy(T* __restrict__ dst, const T* __restrict__ src, bool mask)
 {

@@ -287,12 +292,12 @@ struct Iterator {
 __device__ void Prefetch(bool mask)
 {
-    if constexpr (TURBOMIND_ARCH_SM80) {
-        CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
-    }
-    else {
-        Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
-    }
+    // if constexpr (TURBOMIND_ARCH_SM80) {
+    //     CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
+    // }
+    // else {
+    Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
+    // }
 }
 __device__ void Load(AccessType (&frag)[ThreadMap::kIterC])
src/turbomind/kernels/gemm_s_f16/common.h  (+110 -107)

@@ -14,19 +14,20 @@ namespace turbomind {
 #define TURBOMIND_S4_DEQUANT_USE_FMA 0
 #endif
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
-#define TURBOMIND_ARCH_SM75 1
-#else
-#define TURBOMIND_ARCH_SM75 0
-#endif
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
-#define TURBOMIND_ARCH_SM80 1
-#else
-#define TURBOMIND_ARCH_SM80 0
-#endif
-constexpr int WARP_SIZE = 32;
+// #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
+// #define TURBOMIND_ARCH_SM75 1
+// #else
+// #define TURBOMIND_ARCH_SM75 0
+// #endif
+// #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+// #define TURBOMIND_ARCH_SM80 1
+// #else
+// #define TURBOMIND_ARCH_SM80 0
+// #endif
+// constexpr int WARP_SIZE = 32;
+constexpr int WARP_SIZE = 64;
 #if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
 #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))

@@ -68,22 +69,22 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
The four "lop3.b32" extraction statements for elt_01 / elt_23 / elt_45 / elt_67 and their explanatory comments are commented out with //, and a debug line is added:
 const uint32_t top_i4s = i4s >> 8;
+printf("=========common.h 86\n");
 // I use inline PTX below because I am not sure if the compiler will emit
 // float2half instructions if I use the half2 ctor. In this case, I chose
 // performance reliability over code readability.

@@ -101,13 +102,13 @@
The four output-conversion statements (sub.f16x2 / fma.rn.f16x2 for elt_01 / elt_23 / elt_45 / elt_67) are commented out with //.
 return result;
 }

@@ -130,27 +131,27 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
 const uint32_t top_i4s = i4s >> 8;
+printf("=========common.h 133\n");
Both branches of the `if (0) { // 1024 & 64 ... } else { // 64 only, trade 4 hfma2 with 2 shifts ... }` block are commented out with //: the lop3.b32 and sub.f16x2 statements, the `h[0] <<= 4; h[2] <<= 4;` shifts, and the note that the magic numbers cancel out after subtracting zeros.
 return result;
 }

@@ -158,62 +159,64 @@
 __inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
 {
     uint32_t smem_int_ptr;
-    asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
-        : "=r"(smem_int_ptr)
-        : "l"(ptr));
+    printf("=========common.h 161\n");
+    // asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
+    //     : "=r"(smem_int_ptr)
+    //     : "l"(ptr));
     return smem_int_ptr;
 }
In ldmatrix_m8n8_x4_b16 and ldmatrix_m8n8_x2_b16, the "ldmatrix.sync.aligned.m8n8.x4.shared.b16" / "ldmatrix.sync.aligned.m8n8.x2.shared.b16" inline PTX, together with the #if TURBOMIND_ARCH_SM75 / assert(TURBOMIND_ARCH_SM75) guards, is commented out and replaced with printf("=========common.h 171\n") and printf("=========common.h 183\n") respectively.
The wait_flag(int* lock, int status, int thread_id) and release_flag(int* lock, int status, int thread_id) spin-lock helpers (built on ld.global.acquire.gpu.b32 / st.global.release.gpu.b32, with ld.global.cg.b32 / st.global.cg.b32 fallbacks below __CUDA_ARCH__ 700, and __syncthreads() memory fences) are commented out in their entirety.

@@ -223,14 +226,14 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
 auto& t = (const uint&)x;
 uint  u, v;
-if (TURBOMIND_S4_DEQUANT_USE_FMA) {
-    asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
-}
-else {
-    asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
-    asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
-}
+// if (TURBOMIND_S4_DEQUANT_USE_FMA) {
+//     asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
+// }
+// else {
+//     asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
+//     asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
+// }
+printf("=========common.h 235\n");
 return (half2&)v;
 }
src/turbomind/kernels/gemm_s_f16/cta_iterator.h  (+80 -80)

@@ -8,73 +8,73 @@ namespace turbomind {
The whole cp.async helper block is commented out line by line with //: the L2_CACHEHINT(size) macro definition (guarded by __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4) and the three templated helpers cp_async_cg_A, cp_async_cg_B and cp_async_ca, including their static_assert(cp_size == 16, "cp.async.cg requreis cp_size == 16") checks, the "@p cp.async.cg.shared.global" / "@p cp.async.ca.shared.global" inline-PTX blocks and the assert(TURBOMIND_ARCH_SM80) fallbacks.
 template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES>
 struct IteratorA {

@@ -237,13 +237,13 @@ struct IteratorA {
 __device__ void prefetch(bool mask)
 {
-#if TURBOMIND_ARCH_SM80
-    cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
-#else
+// #if TURBOMIND_ARCH_SM80
+//     cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
+// #else
     if (mask) {
         *(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
     }
-#endif
+// #endif
 }
 };

@@ -424,13 +424,13 @@ struct IteratorQ -- the same change: the TURBOMIND_ARCH_SM80 / cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask) path and its #if/#else/#endif guards are commented out, leaving only the __ldg fallback.

@@ -626,14 +626,14 @@ struct IteratorB -- likewise: the cp_async_cg_B(smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask) path and its guards are commented out, leaving only the __ldg fallback on is_valid_n_ && mask.
src/turbomind/kernels/gemm_s_f16/gemm_template.h
View file @
3253240a
...
@@ -9,41 +9,41 @@
...
@@ -9,41 +9,41 @@
namespace
turbomind
{
namespace
turbomind
{
__inline__
__device__
void
//
__inline__ __device__ void
mma_m16n8k8_row_col
(
Array
<
float
,
4
>&
d
,
const
Array
<
half
,
4
>&
a
,
const
Array
<
half
,
2
>&
b
,
Array
<
float
,
4
>&
c
)
//
mma_m16n8k8_row_col(Array<float, 4>& d, const Array<half, 4>& a, const Array<half, 2>& b, Array<float, 4>& c)
{
//
{
#if TURBOMIND_ARCH_SM75
//
#if TURBOMIND_ARCH_SM75
uint32_t
const
*
A
=
reinterpret_cast
<
uint32_t
const
*>
(
&
a
);
//
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
uint32_t
const
*
B
=
reinterpret_cast
<
uint32_t
const
*>
(
&
b
);
//
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
float
const
*
C
=
reinterpret_cast
<
float
const
*>
(
&
c
);
//
float const* C = reinterpret_cast<float const*>(&c);
float
*
D
=
reinterpret_cast
<
float
*>
(
&
d
);
//
float* D = reinterpret_cast<float*>(&d);
asm
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
//
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
"{%4,%5}, {%6}, {%7,%8,%9,%10};
\n
"
//
"{%4,%5}, {%6}, {%7,%8,%9,%10};\n"
:
"=f"
(
D
[
0
]),
"=f"
(
D
[
1
]),
"=f"
(
D
[
2
]),
"=f"
(
D
[
3
])
//
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
:
"r"
(
A
[
0
]),
"r"
(
A
[
1
]),
"r"
(
B
[
0
]),
"f"
(
C
[
0
]),
"f"
(
C
[
1
]),
"f"
(
C
[
2
]),
"f"
(
C
[
3
]));
//
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
//
#else
assert
(
TURBOMIND_ARCH_SM75
);
//
assert(TURBOMIND_ARCH_SM75);
#endif
//
#endif
}
//
}
 __inline__ __device__ void
 mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
 {
-#if TURBOMIND_ARCH_SM80
-    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
-    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
-    float const*    C = reinterpret_cast<float const*>(&c);
-    float*          D = reinterpret_cast<float*>(&d);
-    asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
-        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
-        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
-        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
-#else
-    const Array<half, 4>* _a = (const Array<half, 4>*)&a;
-    const Array<half, 2>* _b = (const Array<half, 2>*)&b;
-    mma_m16n8k8_row_col(d, _a[0], _b[0], c);
-    mma_m16n8k8_row_col(d, _a[1], _b[1], d);
-#endif
+// #if TURBOMIND_ARCH_SM80
+//     uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
+//     uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
+//     float const*    C = reinterpret_cast<float const*>(&c);
+//     float*          D = reinterpret_cast<float*>(&d);
+//     asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
+//         "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+//         : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
+//         : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
+// #else
+//     const Array<half, 4>* _a = (const Array<half, 4>*)&a;
+//     const Array<half, 2>* _b = (const Array<half, 2>*)&b;
+//     mma_m16n8k8_row_col(d, _a[0], _b[0], c);
+//     mma_m16n8k8_row_col(d, _a[1], _b[1], d);
+// #endif
 }
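The pre-SM80 `#else` branch above (removed along with the rest of the body in this commit) computes the k=16 product as two chained k=8 products, reusing `d` as the accumulator for the second half. A plain host-side illustration of that k-split on float matrices (it ignores the warp-level register layout and the column-major `B` of the PTX variant):

```cpp
// Host-side illustration (not repo code): D = A*B + C for a 16x8 output with
// k = 16, computed as two chained k = 8 partial products, mirroring the two
// mma_m16n8k8_row_col calls in the pre-SM80 fallback.
constexpr int M = 16, N = 8, K = 16;

void gemm_slice(const float* A, const float* B, const float* C, float* D, int k0, int k1)
{
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = C[m * N + n];
            for (int k = k0; k < k1; ++k)
                acc += A[m * K + k] * B[k * N + n];
            D[m * N + n] = acc;
        }
    }
}

void gemm_k_split(const float* A, const float* B, const float* C, float* D)
{
    gemm_slice(A, B, C, D, 0, 8);   // first k-slice: D = A[:, 0:8] * B[0:8, :] + C
    gemm_slice(A, B, D, D, 8, 16);  // second k-slice accumulates onto the partial D
}
```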
 __inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_id)
...
@@ -64,29 +64,29 @@ __inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_
     return (uint&)r;
 }

-#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
-__inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
-{
-#if TURBOMIND_ARCH_SM75
-    uint d;
-    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(d) : "r"(a));
-    return d;
-#else
-    assert(TURBOMIND_ARCH_SM75);
-    return 0;
-#endif
-}
-#endif
+// #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
+// __inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
+// {
+// #if TURBOMIND_ARCH_SM75
+//     uint d;
+//     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(d) : "r"(a));
+//     return d;
+// #else
+//     assert(TURBOMIND_ARCH_SM75);
+//     return 0;
+// #endif
+// }
+// #endif

 __inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
 {
-#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
-    (void)lane_id;
-    return transpose_m8n8_b16_movmatrix(a);
-#else
-    return transpose_m8n8_b16_warp_shuffle(a, lane_id);
-#endif
+// #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
+//     (void)lane_id;
+//     return transpose_m8n8_b16_movmatrix(a);
+// #else
+//     return transpose_m8n8_b16_warp_shuffle(a, lane_id);
+// #endif
 }
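Both helpers above implement the same operation: a warp-cooperative transpose of an 8x8 tile of 16-bit values, either via the `movmatrix` instruction (SM75+, CUDA 11.8+) or via warp shuffles. For orientation, a host-side reference of the result, ignoring how the 64 elements are spread across the 32 lanes (each lane holds one 32-bit register, i.e. two b16 values):

```cpp
#include <cstdint>

// Host reference (illustration only): the result the m8n8.b16 transpose
// primitives produce, expressed on a plain 8x8 array of 16-bit values.
void transpose_m8n8_b16_ref(const uint16_t in[8][8], uint16_t out[8][8])
{
    for (int r = 0; r < 8; ++r)
        for (int c = 0; c < 8; ++c)
            out[c][r] = in[r][c];  // out = in^T
}
```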
 namespace ops {
...
@@ -242,7 +242,7 @@ struct Gemm {
             constexpr int      SLICE_GROUP = (SLICES + 7) / 8;
             constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE;
             const uint32_t     barrier_id  = slice_id / SLICE_GROUP + 1;
-            asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
+            // asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
         }
     }
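The disabled line above issues a PTX named barrier (`bar.sync id, count`) so that only the threads of one slice group synchronize, rather than the whole thread block. A self-contained sketch of the mechanism, using an assumed demo kernel rather than anything from the repo:

```cpp
#include <cstdint>

// Assumed demo kernel: two groups of 128 threads each synchronize on their
// own named barrier (1 and 2); barrier 0 is the one __syncthreads() uses.
__global__ void named_barrier_demo()
{
    const int      group       = threadIdx.x / 128;  // 0 or 1 with 256 threads per block
    const uint32_t barrier_id  = group + 1;           // named barriers 1..15
    const uint32_t num_threads = 128;                 // must be a multiple of the warp size
    asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
    // Threads of the same group are now synchronized; the other group is unaffected.
}
```

Launched as `named_barrier_demo<<<1, 256>>>();`, each half of the block waits only for its own 128 threads.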
...
src/turbomind/models/llama/CMakeLists.txt View file @ 3253240a
...
@@ -50,7 +50,7 @@ if (NOT MSVC)
 endif ()

 add_executable(llama_gemm llama_gemm.cc)
-target_link_libraries(llama_gemm PUBLIC cudart gpt_gemm_func memory_utils cuda_utils logger)
+target_link_libraries(llama_gemm PUBLIC -lrocblas cudart gpt_gemm_func memory_utils cuda_utils logger)
 install(TARGETS llama_gemm DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/bin)

 find_package(Catch2 3 QUIET)
...
src/turbomind/models/llama/llama_decoder_kernels.cu View file @ 3253240a
...
@@ -5,7 +5,7 @@
 #include "src/turbomind/utils/cuda_type_utils.cuh"
 #include "src/turbomind/utils/cuda_utils.h"
 #include <cooperative_groups.h>
-#include <cooperative_groups/reduce.h>
+// #include <cooperative_groups/reduce.h>
 #include <cuda_fp16.h>

 namespace cg = cooperative_groups;
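With `<cooperative_groups/reduce.h>` (and therefore `cg::reduce`) taken out, the usual portable stand-in is a shuffle-based warp reduction. A minimal sketch, not taken from this repo:

```cpp
// Minimal sketch (not repo code): warp-wide sum via butterfly shuffles, the
// common replacement for cg::reduce(tile, v, cg::plus<float>()).
__device__ float warp_reduce_sum(float v)
{
    for (int mask = 16; mask > 0; mask >>= 1)
        v += __shfl_xor_sync(0xffffffff, v, mask);
    return v;  // every lane of the warp now holds the 32-lane sum
}
```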
...
src/turbomind/models/llama/llama_kernels.cu View file @ 3253240a
...
@@ -962,8 +962,8 @@ void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cud
 // template class FlashAttentionOp<float>;
 // template class FlashAttentionOp<half>;
-#ifdef ENABLE_BF16
-template class FlashAttentionOp<__nv_bfloat16>;
-#endif
+// #ifdef ENABLE_BF16
+// template class FlashAttentionOp<__nv_bfloat16>;
+// #endif

 }  // namespace turbomind
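The removed lines follow the usual pattern of guarding an explicit template instantiation behind a feature macro. A self-contained sketch of that pattern with illustrative names only:

```cpp
#include <cuda_fp16.h>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif

// Illustrative only: one explicit instantiation per supported dtype, with the
// bf16 one emitted only when ENABLE_BF16 is defined at compile time.
template<typename T>
struct ExampleOp {
    T scale;
};

template struct ExampleOp<float>;
template struct ExampleOp<half>;
#ifdef ENABLE_BF16
template struct ExampleOp<__nv_bfloat16>;
#endif
```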
src/turbomind/models/llama/unified_attention_layer.cc View file @ 3253240a
...
@@ -64,7 +64,8 @@ void UnifiedAttentionLayer<T>::allocateBuffer(size_t num_token,
     k_buf_2_ = q_buf_2_ + local_head_num_ * bsz * max_q * size_per_head_;
     v_buf_2_ = k_buf_2_ + local_kv_head_num_ * bsz * max_q * size_per_head_;

-    if (use_fmha_) {
+    // if (use_fmha_) {
+    if (0) {
         FlashAttentionOp<T> flash_attention(bsz, local_head_num_, max_k, max_q, size_per_head_);
         if (flash_attention.get_workspace_size() > 0) {
             qk_buf_float_ =
...
@@ -106,7 +107,7 @@ void UnifiedAttentionLayer<T>::freeBuffer()
         allocator_->free((void**)(&q_buf_2_));
         allocator_->free((void**)(&qkv_buf_3_));
-        allocator_->free((void**)&qk_buf_float_);
+        // allocator_->free((void**)&qk_buf_float_);
         allocator_->free((void**)(&k_cache_buf_));
         allocator_->free((void**)(&qk_buf_));
         allocator_->free((void**)(&qkv_buf_2_));
...
@@ -346,7 +347,8 @@ void UnifiedAttentionLayer<T>::prefill(T* output,
                            stream_);
     sync_check_cuda_error();

-    if (use_fmha_) {
+    // if (use_fmha_) {
+    if (0) {
         fusedMultiHeadAttention(output,
                                 q_buf_2_,
                                 tmp_k_ptrs,
...
@@ -456,66 +458,66 @@ void UnifiedAttentionLayer<T>::decode(T* output,
     }
 }

-template<typename T>
-void UnifiedAttentionLayer<T>::fusedMultiHeadAttention(T*     output,
-                                                       const T* query,
-                                                       T**    key_cache_ptrs,
-                                                       T**    val_cache_ptrs,
-                                                       size_t cache_layer_offset,
-                                                       T*     attention_mask,
-                                                       int*   cu_seqlens,
-                                                       int*   context_lengths,
-                                                       int    batch_size,
-                                                       int    max_q_len,
-                                                       int    max_k_len,
-                                                       int    max_seq_len)
-{
-    //////////////////////////////////////////////
-    // flash attention
-    // flash attention 2 only support half inputs
-    using AttentionOp = FlashAttentionOp<T>;
-    using Layout      = typename AttentionOp::AttentionLayout;
-    Layout layout_q{
-        int(local_head_num_ * max_q_len * size_per_head_), int(size_per_head_), int(max_q_len * size_per_head_)};
-    Layout layout_k{int(local_head_num_ * max_seq_len * size_per_head_),
-                    int(size_per_head_),
-                    int(max_seq_len * size_per_head_),
-                    false,
-                    cache_layer_offset,
-                    key_cache_ptrs};
-    Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_),
-                    int(size_per_head_),
-                    int(max_seq_len * size_per_head_),
-                    false,
-                    cache_layer_offset,
-                    val_cache_ptrs};
-    Layout layout_o{
-        int(local_head_num_ * max_q_len * size_per_head_),
-        int(local_head_num_ * size_per_head_),
-        int(size_per_head_),
-        true,
-    };
-    size_t group_size = size_t(local_head_num_ / local_kv_head_num_);
-    AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
-    typename AttentionOp::Params attn_params{output,
-                                             (T*)query,
-                                             k_cache_buf_,
-                                             v_cache_buf_,
-                                             attention_mask,
-                                             qk_buf_float_,
-                                             cu_seqlens,
-                                             nullptr,
-                                             nullptr,
-                                             context_lengths,
-                                             group_size,
-                                             layout_q,
-                                             layout_k,
-                                             layout_v,
-                                             layout_o};
-    //
-    flash_attention(attn_params, stream_);
-}
+// template<typename T>
+// void UnifiedAttentionLayer<T>::fusedMultiHeadAttention(T* output,
+//                                                        const T* query,
+//                                                        T** key_cache_ptrs,
+//                                                        T** val_cache_ptrs,
+//                                                        size_t cache_layer_offset,
+//                                                        T* attention_mask,
+//                                                        int* cu_seqlens,
+//                                                        int* context_lengths,
+//                                                        int batch_size,
+//                                                        int max_q_len,
+//                                                        int max_k_len,
+//                                                        int max_seq_len)
+// {
+//     //////////////////////////////////////////////
+//     // flash attention
+//     // flash attention 2 only support half inputs
+//     using AttentionOp = FlashAttentionOp<T>;
+//     using Layout      = typename AttentionOp::AttentionLayout;
+//     Layout layout_q{
+//         int(local_head_num_ * max_q_len * size_per_head_), int(size_per_head_), int(max_q_len * size_per_head_)};
+//     Layout layout_k{int(local_head_num_ * max_seq_len * size_per_head_),
+//                     int(size_per_head_),
+//                     int(max_seq_len * size_per_head_),
+//                     false,
+//                     cache_layer_offset,
+//                     key_cache_ptrs};
+//     Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_),
+//                     int(size_per_head_),
+//                     int(max_seq_len * size_per_head_),
+//                     false,
+//                     cache_layer_offset,
+//                     val_cache_ptrs};
+//     Layout layout_o{
+//         int(local_head_num_ * max_q_len * size_per_head_),
+//         int(local_head_num_ * size_per_head_),
+//         int(size_per_head_),
+//         true,
+//     };
+//     size_t group_size = size_t(local_head_num_ / local_kv_head_num_);
+//     AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
+//     typename AttentionOp::Params attn_params{output,
+//                                              (T*)query,
+//                                              k_cache_buf_,
+//                                              v_cache_buf_,
+//                                              attention_mask,
+//                                              qk_buf_float_,
+//                                              cu_seqlens,
+//                                              nullptr,
+//                                              nullptr,
+//                                              context_lengths,
+//                                              group_size,
+//                                              layout_q,
+//                                              layout_k,
+//                                              layout_v,
+//                                              layout_o};
+//     //
+//     flash_attention(attn_params, stream_);
+// }

 template<typename T>
 void UnifiedAttentionLayer<T>::unfusedMultiHeadAttention(T* output,
...
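One detail worth keeping in mind from the now-disabled `fusedMultiHeadAttention`: `group_size = size_t(local_head_num_ / local_kv_head_num_)` is the grouped-query-attention ratio, i.e. how many query heads share one KV head. A standalone illustration with assumed head counts (32 query heads, 8 KV heads), not repo code:

```cpp
#include <cstdio>

// Illustration only: mapping query heads onto shared KV heads via the same
// group_size = head_num / kv_head_num ratio used above.
int main()
{
    const int head_num    = 32;                       // assumed query-head count
    const int kv_head_num = 8;                        // assumed KV-head count
    const int group_size  = head_num / kv_head_num;   // 4 query heads per KV head
    for (int q = 0; q < head_num; q += group_size) {
        const int kv = q / group_size;                // KV head used by this group
        std::printf("query heads %d..%d -> kv head %d\n", q, q + group_size - 1, kv);
    }
    return 0;
}
```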
src/turbomind/triton_backend/llama/LlamaTritonModel.cc View file @ 3253240a
...
@@ -47,6 +47,7 @@ std::shared_ptr<AbstractTransformerModel> AbstractTransformerModel::createLlamaM
             reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0),
             model_dir);
     }
+#ifdef ENABLE_BF16
     else if (data_type == "bf16") {
 #ifdef ENABLE_BF16
         return std::make_shared<LlamaTritonModel<__nv_bfloat16>>(
...
@@ -59,6 +60,7 @@ std::shared_ptr<AbstractTransformerModel> AbstractTransformerModel::createLlamaM
         ft::FT_CHECK(false);
 #endif
     }
+#endif
     else {
         return std::make_shared<LlamaTritonModel<float>>(
             reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"),
...
@@ -177,7 +179,8 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     norm_eps_            = reader.GetFloat("llama", "norm_eps");
     start_id_            = reader.GetInteger("llama", "start_id");
     end_id_              = reader.GetInteger("llama", "end_id");
-    use_context_fmha_    = reader.GetInteger("llama", "use_context_fmha", 1);
+    // use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1);
+    use_context_fmha_    = 0;
     cache_block_seq_len_ = reader.GetInteger("llama", "cache_block_seq_len", 0);
     attn_bias_           = reader.GetInteger("llama", "attn_bias", 0);
...
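Two patterns in this file are worth spelling out: the added outer `#ifdef ENABLE_BF16` removes the entire `bf16` branch from the dispatch chain when bf16 support is not compiled in (instead of falling into `FT_CHECK(false)` at runtime), and `use_context_fmha_` is now forced to 0 regardless of the config value. A self-contained sketch of the reshaped dispatch, with illustrative types that merely stand in for the repo's classes:

```cpp
#include <memory>
#include <string>

// Illustration only: with ENABLE_BF16 undefined, the "bf16" branch disappears
// from the if/else chain, so unknown or unsupported types fall through to the
// float path at compile time.
struct ModelBase { virtual ~ModelBase() = default; };
template<typename T> struct Model: ModelBase {};

std::shared_ptr<ModelBase> create(const std::string& data_type)
{
    if (data_type == "fp16") {
        return std::make_shared<Model<short>>();  // stands in for half
    }
#ifdef ENABLE_BF16
    else if (data_type == "bf16") {
        return std::make_shared<Model<int>>();    // stands in for __nv_bfloat16
    }
#endif
    else {
        return std::make_shared<Model<float>>();
    }
}
```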