Unverified Commit fe46dac2 authored by AllentDan's avatar AllentDan Committed by GitHub
Browse files

Add lint action (#32)

* temp

* fix lint

* csrc->src

* remove clang-format

* skip .rst

* skip doc

* clang-format

version

version

* mat_B
parent e8ab4ba3
# Copyright (c) MegFlow. All rights reserved.
# /bin/python3
import argparse
import os
import re
def make_parser():
    """Build the command-line parser for the doc link checker.

    Returns:
        argparse.ArgumentParser: parser exposing ``--http`` (enable http
        link checking) and ``--target`` (directory or file to check).
    """
    parser = argparse.ArgumentParser('Doc link checker')
    # BUGFIX: the original used ``type=bool``, an argparse pitfall —
    # bool() of any non-empty string (including "False") is True, so the
    # flag could never be turned off from the command line.  A
    # ``store_true`` action gives real on/off flag behavior while keeping
    # the same default (False) for existing callers that omit the flag.
    parser.add_argument('--http',
                        default=False,
                        action='store_true',
                        help='check http or not ')
    parser.add_argument('--target',
                        default='./docs',
                        type=str,
                        help='the directory or file to check')
    return parser
# Matches markdown links of the form [text](target), non-greedy.
pattern = re.compile(r'\[.*?\]\(.*?\)')


def analyze_doc(home, path):
    """Verify every relative markdown link in ``path`` points to a real file.

    Links inside fenced code blocks, http(s) links, pure ``#anchor`` links
    and image links ``![]()`` are skipped.

    Args:
        home (str): directory used to resolve relative link targets.
        path (str): markdown file to scan.

    Raises:
        Exception: if at least one relative link target does not exist
            (all offenders are printed first).
    """
    print('analyze {}'.format(path))
    problem_list = []
    code_block = 0
    with open(path) as f:
        lines = f.readlines()
    for line in lines:
        line = line.strip()
        # Toggle on every ``` fence; content between fences is ignored.
        if line.startswith('```'):
            code_block = 1 - code_block
        if code_block > 0:
            continue
        if '[' in line and ']' in line and '(' in line and ')' in line:
            matches = pattern.findall(line)  # renamed: don't shadow builtin all()
            for item in matches:
                # skip image links ![]()
                if item.find('[') == item.find(']') - 1:
                    continue
                # handle the case [text()](link): jump to the '](' separator
                offset = item.find('](')
                if offset == -1:
                    continue
                item = item[offset:]
                start = item.find('(')
                end = item.find(')')
                ref = item[start + 1:end]
                if ref.startswith('http') or ref.startswith('#'):
                    continue
                # BUGFIX: for "file.md#anchor" check the file part, not the
                # anchor.  The original kept the "#anchor" substring
                # (ref[ref.find('#'):]), so every anchored link resolved to
                # a path like "home/#anchor" and was reported as broken.
                if '.md#' in ref:
                    ref = ref[:ref.find('#')]
                fullpath = os.path.join(home, ref)
                if not os.path.exists(fullpath):
                    problem_list.append(ref)
    if len(problem_list) > 0:
        print(f'{path}:')
        for item in problem_list:
            print(f'\t {item}')
        print('\n')
        raise Exception('found link error')
def traverse(target):
    """Run the link check on ``target``.

    A single file is checked directly; a directory is walked recursively
    and every non-symlinked ``.md`` file inside it is checked.

    Args:
        target (str): markdown file or directory root to check.
    """
    if os.path.isfile(target):
        analyze_doc(os.path.dirname(target), target)
        return
    for home, _dirs, files in os.walk(target):
        for filename in files:
            if not filename.endswith('.md'):
                continue
            path = os.path.join(home, filename)
            # Symlinks are skipped so a file is never analyzed twice.
            if not os.path.islink(path):
                analyze_doc(home, path)
if __name__ == '__main__':
    # Parse CLI options and kick off the link check on the chosen target.
    cli_args = make_parser().parse_args()
    traverse(cli_args.target)
# GitHub Actions workflow: lint the project on every push and pull request.
name: lint

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.7
        uses: actions/setup-python@v2
        with:
          python-version: 3.7
      - name: Install pre-commit hook
        run: |
          python -m pip install pre-commit
          pre-commit install
      - name: Linting
        run: pre-commit run --all-files
      # clang-format runs only on C/C++/CUDA sources under src/ using the
      # repository's .clang-format config (style: file).
      - name: Format c/cuda codes with clang-format
        uses: DoozyX/clang-format-lint-action@v0.14
        with:
          source: src
          extensions: h,c,cpp,hpp,cu,cuh
          clangFormatVersion: 14
          style: file
      - name: Check markdown link
        uses: gaurav-nelson/github-action-markdown-link-check@v1
        with:
          use-quiet-mode: 'yes'
          use-verbose-mode: 'yes'
          # check-modified-files-only: 'yes'
          config-file: '.github/md-link-config.json'
          file-path: './README.md, ./LICENSE, ./README_zh-CN.md'
      # Relative-link check via the in-repo script (see doc_link_checker.py).
      - name: Check doc link
        run: |
          python .github/scripts/doc_link_checker.py --target README_zh-CN.md
          python .github/scripts/doc_link_checker.py --target README.md
      - name: Check docstring coverage
        run: |
          python -m pip install interrogate
          interrogate -v --ignore-init-method --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 80 lmdeploy
      - name: Check pylint score
        run: |
          python -m pip install pylint
          pylint lmdeploy
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
inih is released under the New BSD license (see LICENSE.txt). Go to the project inih is released under the New BSD license (see LICENSE.txt). Go to the project
home page for more info: home page for more info:
https://github.com/benhoyt/inih https://github.com/benhoyt/inih
https://github.com/jtilly/inih https://github.com/jtilly/inih
*/ */
#ifndef __INI_H__ #ifndef __INI_H__
...@@ -344,7 +344,7 @@ public: ...@@ -344,7 +344,7 @@ public:
// according to strtof(). // according to strtof().
float GetFloat(std::string section, std::string name, float default_value) const; float GetFloat(std::string section, std::string name, float default_value) const;
float GetFloat(std::string section, std::string name) const; float GetFloat(std::string section, std::string name) const;
// Get a boolean value from INI file, returning default_value if not found or if // Get a boolean value from INI file, returning default_value if not found or if
// not a valid true/false value. Valid true values are "true", "yes", "on", "1", // not a valid true/false value. Valid true values are "true", "yes", "on", "1",
// and valid false values are "false", "no", "off", "0" (not case sensitive). // and valid false values are "false", "no", "off", "0" (not case sensitive).
...@@ -498,4 +498,4 @@ inline int INIReader::ValueHandler(void* user, const char* section, const char* ...@@ -498,4 +498,4 @@ inline int INIReader::ValueHandler(void* user, const char* section, const char*
return 1; return 1;
} }
#endif // __INIREADER__ #endif // __INIREADER__
\ No newline at end of file
...@@ -43,7 +43,7 @@ endif() ...@@ -43,7 +43,7 @@ endif()
include(FetchContent) include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
repo-cutlass repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG cc85b64cf676c45f98a17e3a47c0aafcf817f088 GIT_TAG cc85b64cf676c45f98a17e3a47c0aafcf817f088
) )
......
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
# #
# From PyTorch: # From PyTorch:
# #
# Copyright (c) 2016- Facebook, Inc (Adam Paszke) # Copyright (c) 2016- Facebook, Inc (Adam Paszke)
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) # Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
...@@ -11,57 +11,57 @@ ...@@ -11,57 +11,57 @@
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) # Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
# #
# From Caffe2: # From Caffe2:
# #
# Copyright (c) 2016-present, Facebook Inc. All rights reserved. # Copyright (c) 2016-present, Facebook Inc. All rights reserved.
# #
# All contributions by Facebook: # All contributions by Facebook:
# Copyright (c) 2016 Facebook Inc. # Copyright (c) 2016 Facebook Inc.
# #
# All contributions by Google: # All contributions by Google:
# Copyright (c) 2015 Google Inc. # Copyright (c) 2015 Google Inc.
# All rights reserved. # All rights reserved.
# #
# All contributions by Yangqing Jia: # All contributions by Yangqing Jia:
# Copyright (c) 2015 Yangqing Jia # Copyright (c) 2015 Yangqing Jia
# All rights reserved. # All rights reserved.
# #
# All contributions by Kakao Brain: # All contributions by Kakao Brain:
# Copyright 2019-2020 Kakao Brain # Copyright 2019-2020 Kakao Brain
# #
# All contributions from Caffe: # All contributions from Caffe:
# Copyright(c) 2013, 2014, 2015, the respective contributors # Copyright(c) 2013, 2014, 2015, the respective contributors
# All rights reserved. # All rights reserved.
# #
# All other contributions: # All other contributions:
# Copyright(c) 2015, 2016 the respective contributors # Copyright(c) 2015, 2016 the respective contributors
# All rights reserved. # All rights reserved.
# #
# Caffe2 uses a copyright model similar to Caffe: each contributor holds # Caffe2 uses a copyright model similar to Caffe: each contributor holds
# copyright over their contributions to Caffe2. The project versioning records # copyright over their contributions to Caffe2. The project versioning records
# all such contribution and copyright details. If a contributor wants to further # all such contribution and copyright details. If a contributor wants to further
# mark their specific copyright on a particular contribution, they should # mark their specific copyright on a particular contribution, they should
# indicate their copyright solely in the commit message of the change when it is # indicate their copyright solely in the commit message of the change when it is
# committed. # committed.
# #
# All rights reserved. # All rights reserved.
# #
# Redistribution and use in source and binary forms, with or without # Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met: # modification, are permitted provided that the following conditions are met:
# #
# 1. Redistributions of source code must retain the above copyright # 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer. # notice, this list of conditions and the following disclaimer.
# #
# 2. Redistributions in binary form must reproduce the above copyright # 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the # notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution. # documentation and/or other materials provided with the distribution.
# #
# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America # 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
# and IDIAP Research Institute nor the names of its contributors may be # and IDIAP Research Institute nor the names of its contributors may be
# used to endorse or promote products derived from this software without # used to endorse or promote products derived from this software without
# specific prior written permission. # specific prior written permission.
# #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
...@@ -73,7 +73,7 @@ ...@@ -73,7 +73,7 @@
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE. # POSSIBILITY OF SUCH DAMAGE.
# #
# Find the nccl libraries # Find the nccl libraries
# #
# The following variables are optionally searched for defaults # The following variables are optionally searched for defaults
......
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
add_subdirectory(cpp) add_subdirectory(cpp)
\ No newline at end of file
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
add_subdirectory(llama) add_subdirectory(llama)
\ No newline at end of file
...@@ -2,5 +2,5 @@ ...@@ -2,5 +2,5 @@
add_executable(llama_triton_example llama_triton_example.cc) add_executable(llama_triton_example llama_triton_example.cc)
target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart
LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils
nvtx_utils word_list glog) nvtx_utils word_list glog)
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import subprocess import subprocess
import fire import fire
...@@ -12,7 +13,8 @@ def main(head_num: int = 32, ...@@ -12,7 +13,8 @@ def main(head_num: int = 32,
max_batch_size: int = 64): max_batch_size: int = 64):
for bsz in range(1, max_batch_size + 1): for bsz in range(1, max_batch_size + 1):
subprocess.call( subprocess.call(
f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size} {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}', f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size}'
f' {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}',
shell=True) shell=True)
......
...@@ -78,5 +78,3 @@ rotary_embedding=128 ...@@ -78,5 +78,3 @@ rotary_embedding=128
start_id=1 start_id=1
end_id=2 end_id=2
inter_size=22016 inter_size=22016
...@@ -5,4 +5,4 @@ ...@@ -5,4 +5,4 @@
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,36589,3467,7849,299,7032,46323,13,44975,45004,11130,32843,45004,35597 0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,36589,3467,7849,299,7032,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,44976,39798,6828,3467,46323,13,44975,45004,11130,32843,45004,35597 0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,44976,39798,6828,3467,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,2795,977,9193,299,405,537,46323,13,44975,45004,11130,32843,45004,35597 0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,2795,977,9193,299,405,537,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,45691,45926,45513,46641,47641,46285,6456,46323,13,44975,45004,11130,32843,45004,35597 0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,45691,45926,45513,46641,47641,46285,6456,46323,13,44975,45004,11130,32843,45004,35597
\ No newline at end of file
from typing import List from typing import List
import fire import fire
......
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
add_subdirectory(fastertransformer) add_subdirectory(fastertransformer)
\ No newline at end of file
...@@ -19,4 +19,4 @@ add_subdirectory(models) ...@@ -19,4 +19,4 @@ add_subdirectory(models)
if(BUILD_PYT) if(BUILD_PYT)
add_subdirectory(th_op) add_subdirectory(th_op)
endif() endif()
add_subdirectory(triton_backend) add_subdirectory(triton_backend)
\ No newline at end of file
...@@ -467,4 +467,4 @@ template void invokeQuantizeMatrixRebuildPadding<half, __nv_fp8_e4m3, QUANTIZE_M ...@@ -467,4 +467,4 @@ template void invokeQuantizeMatrixRebuildPadding<half, __nv_fp8_e4m3, QUANTIZE_M
#endif #endif
} // namespace fastertransformer } // namespace fastertransformer
\ No newline at end of file
...@@ -395,4 +395,4 @@ template void invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(AllReduceParams<_ ...@@ -395,4 +395,4 @@ template void invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(AllReduceParams<_
cudaStream_t stream); cudaStream_t stream);
#endif #endif
template void invokeOneOrTwoShotAllReduceKernel<uint32_t>(AllReduceParams<uint32_t>& param, cudaStream_t stream); template void invokeOneOrTwoShotAllReduceKernel<uint32_t>(AllReduceParams<uint32_t>& param, cudaStream_t stream);
} // namespace fastertransformer } // namespace fastertransformer
\ No newline at end of file
...@@ -60,4 +60,4 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams<T>& param, cudaStream_t s ...@@ -60,4 +60,4 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams<T>& param, cudaStream_t s
void kernelLaunchConfig(int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo); void kernelLaunchConfig(int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo);
} // namespace fastertransformer } // namespace fastertransformer
\ No newline at end of file
...@@ -116,8 +116,8 @@ struct Multihead_attention_params_base { ...@@ -116,8 +116,8 @@ struct Multihead_attention_params_base {
const float* qkv_scale_out = nullptr; const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr; const float* attention_out_scale = nullptr;
int int8_mode = 0; int int8_mode = 0;
float attention_k_scale = 0.f; float attention_k_scale = 0.f;
float attention_v_scale = 0.f; float attention_v_scale = 0.f;
}; };
template<typename T> template<typename T>
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/models/llama/llama_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h" #include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh" #include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include "src/fastertransformer/models/llama/llama_utils.h"
#include <assert.h> #include <assert.h>
#include <float.h> #include <float.h>
#include <type_traits> #include <type_traits>
...@@ -81,7 +81,8 @@ namespace mmha { ...@@ -81,7 +81,8 @@ namespace mmha {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Dh> template<typename T, int Dh>
struct Qk_vec_m_ {}; struct Qk_vec_m_ {
};
template<> template<>
struct Qk_vec_m_<float, 32> { struct Qk_vec_m_<float, 32> {
...@@ -181,7 +182,8 @@ struct Qk_vec_k_<__nv_fp8_e4m3, 256> { ...@@ -181,7 +182,8 @@ struct Qk_vec_k_<__nv_fp8_e4m3, 256> {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY> template<typename T, int THREADS_PER_KEY>
struct K_vec_m_ {}; struct K_vec_m_ {
};
template<> template<>
struct K_vec_m_<float, 4> { struct K_vec_m_<float, 4> {
...@@ -262,7 +264,8 @@ struct K_vec_k_<__nv_fp8_e4m3, 1> { ...@@ -262,7 +264,8 @@ struct K_vec_k_<__nv_fp8_e4m3, 1> {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int V_VEC_SIZE> template<typename T, int V_VEC_SIZE>
struct V_vec_m_ {}; struct V_vec_m_ {
};
template<> template<>
struct V_vec_m_<float, 1> { struct V_vec_m_<float, 1> {
...@@ -342,7 +345,8 @@ struct V_vec_k_<__nv_fp8_e4m3, 16> { ...@@ -342,7 +345,8 @@ struct V_vec_k_<__nv_fp8_e4m3, 16> {
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template<typename T> template<typename T>
struct Qk_vec_acum_fp32_ {}; struct Qk_vec_acum_fp32_ {
};
template<> template<>
struct Qk_vec_acum_fp32_<float> { struct Qk_vec_acum_fp32_<float> {
...@@ -424,7 +428,8 @@ struct Qk_vec_acum_fp32_<fp8_4_t> { ...@@ -424,7 +428,8 @@ struct Qk_vec_acum_fp32_<fp8_4_t> {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
struct K_vec_acum_fp32_ {}; struct K_vec_acum_fp32_ {
};
template<> template<>
struct K_vec_acum_fp32_<float> { struct K_vec_acum_fp32_<float> {
...@@ -486,7 +491,8 @@ struct K_vec_acum_fp32_<fp8_4_t> { ...@@ -486,7 +491,8 @@ struct K_vec_acum_fp32_<fp8_4_t> {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template<typename T> template<typename T>
struct V_vec_acum_fp32_ {}; struct V_vec_acum_fp32_ {
};
template<> template<>
struct V_vec_acum_fp32_<float> { struct V_vec_acum_fp32_<float> {
...@@ -1455,14 +1461,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1455,14 +1461,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B
+ tlength_circ * QK_ELTS_IN_16B + ci; + tlength_circ * QK_ELTS_IN_16B + ci;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) { if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type; using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
Packed_Int8_t k_int8 = quant(k, k_scale); Packed_Int8_t k_int8 = quant(k, k_scale);
int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache); int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache);
*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8; *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
} else { }
else {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k); *reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
} }
} }
...@@ -1483,11 +1490,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1483,11 +1490,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]); int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8; *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
} else { }
else {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) = *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
vec_conversion<Qk_vec_m, Qk_vec_k>(k); vec_conversion<Qk_vec_m, Qk_vec_k>(k);
} }
} }
} }
} }
...@@ -1565,29 +1572,29 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1565,29 +1572,29 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
// The base pointer for the key in the cache buffer. // The base pointer for the key in the cache buffer.
T* k_cache_batch = nullptr; T* k_cache_batch = nullptr;
int8_t* k_cache_batch_int8 = nullptr; int8_t* k_cache_batch_int8 = nullptr;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) { if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
// convert k_cache_per_sample to int8 // convert k_cache_per_sample to int8
if (params.k_cache_per_sample) { if (params.k_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]); int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki; k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki;
} else { }
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache); else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache);
k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki]; k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki];
} }
} else { }
T* k_cache = else {
params.k_cache_per_sample ? T* k_cache = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
(params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki) : + hi * params.memory_max_len * Dh + ki) :
&params.k_cache[bhi * params.memory_max_len * Dh + ki]; &params.k_cache[bhi * params.memory_max_len * Dh + ki];
// Base pointer for the beam's batch, before offsetting with indirection buffer // Base pointer for the beam's batch, before offsetting with indirection buffer
// T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki]; // T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
k_cache_batch = k_cache; k_cache_batch = k_cache;
} }
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
...@@ -1626,12 +1633,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1626,12 +1633,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
using Packed_Int8_t = typename packed_type<int8_t, num_elems<K_vec_m>::value>::type; using Packed_Int8_t = typename packed_type<int8_t, num_elems<K_vec_m>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<K_vec_m>::value>::type; using Packed_Float_t = typename packed_type<float, num_elems<K_vec_m>::value>::type;
Packed_Int8_t k_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&k_cache_batch_int8[beam_offset + jj * QK_ELTS_IN_16B]); Packed_Int8_t k_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(
&k_cache_batch_int8[beam_offset + jj * QK_ELTS_IN_16B]);
Packed_Float_t k_vec_m_float = dequant(k_vec_m_int8, k_scale); Packed_Float_t k_vec_m_float = dequant(k_vec_m_int8, k_scale);
k[ii] = vec_conversion<K_vec_k, Packed_Float_t>(k_vec_m_float); k[ii] = vec_conversion<K_vec_k, Packed_Float_t>(k_vec_m_float);
} else { }
k[ii] = vec_conversion<K_vec_k, K_vec_m>((*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); else {
k[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
} }
} }
} }
...@@ -1747,28 +1757,29 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1747,28 +1757,29 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
// The base pointer for the value in the cache buffer. // The base pointer for the value in the cache buffer.
T* v_cache = nullptr; T* v_cache = nullptr;
T* v_cache_batch = nullptr; T* v_cache_batch = nullptr;
int8_t* v_cache_int8 = nullptr; int8_t* v_cache_int8 = nullptr;
int8_t* v_cache_batch_int8 = nullptr; int8_t* v_cache_batch_int8 = nullptr;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) { if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
if (params.v_cache_per_sample) { if (params.v_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]); int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi; v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
} else { }
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache); else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache);
v_cache_int8 = &ptr[bhi * params.memory_max_len * Dh + vi]; v_cache_int8 = &ptr[bhi * params.memory_max_len * Dh + vi];
} }
v_cache_batch_int8 = v_cache_int8; v_cache_batch_int8 = v_cache_int8;
} else { }
else {
v_cache = v_cache = params.v_cache_per_sample ? (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset
params.v_cache_per_sample ? + hi * params.memory_max_len * Dh + vi) :
(params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi) : &params.v_cache[bhi * params.memory_max_len * Dh + vi];
&params.v_cache[bhi * params.memory_max_len * Dh + vi];
// Base pointer for the beam's batch, before offsetting with indirection buffer // Base pointer for the beam's batch, before offsetting with indirection buffer
// T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi]; // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
v_cache_batch = v_cache; v_cache_batch = v_cache;
...@@ -1822,14 +1833,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1822,14 +1833,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache. // Load the values from the cache.
V_vec_k v; V_vec_k v;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) { if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]); Packed_Int8_t v_vec_m_int8 =
*reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale); Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float); v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
} else { }
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh])); else {
v = vec_conversion<V_vec_k, V_vec_m>(
*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
} }
// Load the logits from shared memory. // Load the logits from shared memory.
...@@ -1867,14 +1881,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1867,14 +1881,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache. // Load the values from the cache.
V_vec_k v; V_vec_k v;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) { if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]); Packed_Int8_t v_vec_m_int8 =
*reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale); Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float); v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
} else { }
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh])); else {
v = vec_conversion<V_vec_k, V_vec_m>(
*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
} }
// Load the logits from shared memory. // Load the logits from shared memory.
...@@ -1910,7 +1927,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1910,7 +1927,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Trigger the loads from the V buffer. // Trigger the loads from the V buffer.
const auto v_offset = qkv_base_offset + vi; const auto v_offset = qkv_base_offset + vi;
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset])); v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));
// Trigger the loads from the V bias buffer. // Trigger the loads from the V bias buffer.
// V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]); // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
...@@ -1925,7 +1942,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1925,7 +1942,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type; using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
Packed_Int8_t v_int8 = quant(v, v_scale); Packed_Int8_t v_int8 = quant(v, v_scale);
*reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8; *reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8;
} else { }
else {
*reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v); *reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
} }
} }
...@@ -1994,7 +2012,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1994,7 +2012,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]), convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]),
mul<V_vec_acum, float, V_vec_acum>(result_scale, out)); mul<V_vec_acum, float, V_vec_acum>(result_scale, out));
#endif // FP8_MHA #endif // FP8_MHA
} else { }
else {
convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]), out); convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]), out);
} }
#else // MMHA_USE_FP32_ACUM_FOR_OUT #else // MMHA_USE_FP32_ACUM_FOR_OUT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment