"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "c6aa379de2251e64bab47a2704266abf150080a8"
Unverified commit fe46dac2 authored by AllentDan, committed by GitHub

Add lint action (#32)

* temp

* fix lint

* csrc->src

* remove clang-format

* skip .rst

* skip doc

* clang-format

version

version

* mat_B
parent e8ab4ba3
# Copyright (c) MegFlow. All rights reserved.
# /bin/python3
import argparse
import os
import re


def make_parser():
    parser = argparse.ArgumentParser('Doc link checker')
    parser.add_argument('--http',
                        default=False,
                        type=bool,
                        help='check http or not ')
    parser.add_argument('--target',
                        default='./docs',
                        type=str,
                        help='the directory or file to check')
    return parser


pattern = re.compile(r'\[.*?\]\(.*?\)')


def analyze_doc(home, path):
    print('analyze {}'.format(path))
    problem_list = []
    code_block = 0
    with open(path) as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            if line.startswith('```'):
                code_block = 1 - code_block

            if code_block > 0:
                continue

            if '[' in line and ']' in line and '(' in line and ')' in line:
                all = pattern.findall(line)
                for item in all:
                    # skip ![]()
                    if item.find('[') == item.find(']') - 1:
                        continue

                    # process the case [text()]()
                    offset = item.find('](')
                    if offset == -1:
                        continue
                    item = item[offset:]
                    start = item.find('(')
                    end = item.find(')')
                    ref = item[start + 1:end]

                    if ref.startswith('http') or ref.startswith('#'):
                        continue
                    if '.md#' in ref:
                        ref = ref[ref.find('#'):]
                    fullpath = os.path.join(home, ref)
                    if not os.path.exists(fullpath):
                        problem_list.append(ref)
            else:
                continue
    if len(problem_list) > 0:
        print(f'{path}:')
        for item in problem_list:
            print(f'\t {item}')
        print('\n')
        raise Exception('found link error')


def traverse(target):
    if os.path.isfile(target):
        analyze_doc(os.path.dirname(target), target)
        return
    for home, dirs, files in os.walk(target):
        for filename in files:
            if filename.endswith('.md'):
                path = os.path.join(home, filename)
                if os.path.islink(path) is False:
                    analyze_doc(home, path)


if __name__ == '__main__':
    args = make_parser().parse_args()
    traverse(args.target)
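For reference, a small standalone snippet showing what the `pattern` regex above extracts from a Markdown line before the existence check runs (the sample line is made up for illustration):

```python
import re

pattern = re.compile(r'\[.*?\]\(.*?\)')

line = 'See the [install guide](docs/install.md) or the [project site](https://example.com).'
for item in pattern.findall(line):
    start, end = item.find('('), item.find(')')
    ref = item[start + 1:end]
    print(ref)
# prints:
#   docs/install.md      -> checked against the file system by analyze_doc
#   https://example.com  -> skipped, since refs starting with 'http' are ignored
```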
name: lint

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.7
        uses: actions/setup-python@v2
        with:
          python-version: 3.7
      - name: Install pre-commit hook
        run: |
          python -m pip install pre-commit
          pre-commit install
      - name: Linting
        run: pre-commit run --all-files
      - name: Format c/cuda codes with clang-format
        uses: DoozyX/clang-format-lint-action@v0.14
        with:
          source: src
          extensions: h,c,cpp,hpp,cu,cuh
          clangFormatVersion: 14
          style: file
      - name: Check markdown link
        uses: gaurav-nelson/github-action-markdown-link-check@v1
        with:
          use-quiet-mode: 'yes'
          use-verbose-mode: 'yes'
          # check-modified-files-only: 'yes'
          config-file: '.github/md-link-config.json'
          file-path: './README.md, ./LICENSE, ./README_zh-CN.md'
      - name: Check doc link
        run: |
          python .github/scripts/doc_link_checker.py --target README_zh-CN.md
          python .github/scripts/doc_link_checker.py --target README.md
      - name: Check docstring coverage
        run: |
          python -m pip install interrogate
          interrogate -v --ignore-init-method --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 80 lmdeploy
      - name: Check pylint score
        run: |
          python -m pip install pylint
          pylint lmdeploy
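The "Check doc link" step can also be reproduced outside CI. A minimal sketch (the script path is the one used in the workflow; running against README.md from the repository root is just an example):

```python
# Hypothetical local run of the doc link checker invoked by the workflow above.
# Assumes the repository root is the current working directory.
import runpy
import sys

sys.argv = ['doc_link_checker.py', '--target', 'README.md']
runpy.run_path('.github/scripts/doc_link_checker.py', run_name='__main__')
```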
......@@ -9,7 +9,7 @@
inih is released under the New BSD license (see LICENSE.txt). Go to the project
home page for more info:
https://github.com/benhoyt/inih
https://github.com/jtilly/inih
https://github.com/jtilly/inih
*/
#ifndef __INI_H__
......@@ -344,7 +344,7 @@ public:
// according to strtof().
float GetFloat(std::string section, std::string name, float default_value) const;
float GetFloat(std::string section, std::string name) const;
// Get a boolean value from INI file, returning default_value if not found or if
// not a valid true/false value. Valid true values are "true", "yes", "on", "1",
// and valid false values are "false", "no", "off", "0" (not case sensitive).
......@@ -498,4 +498,4 @@ inline int INIReader::ValueHandler(void* user, const char* section, const char*
return 1;
}
#endif // __INIREADER__
\ No newline at end of file
#endif // __INIREADER__
......@@ -43,7 +43,7 @@ endif()
include(FetchContent)
FetchContent_Declare(
repo-cutlass
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG cc85b64cf676c45f98a17e3a47c0aafcf817f088
)
......
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
#
# From PyTorch:
#
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
......@@ -11,57 +11,57 @@
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
#
# From Caffe2:
#
#
# Copyright (c) 2016-present, Facebook Inc. All rights reserved.
#
#
# All contributions by Facebook:
# Copyright (c) 2016 Facebook Inc.
#
#
# All contributions by Google:
# Copyright (c) 2015 Google Inc.
# All rights reserved.
#
#
# All contributions by Yangqing Jia:
# Copyright (c) 2015 Yangqing Jia
# All rights reserved.
#
#
# All contributions by Kakao Brain:
# Copyright 2019-2020 Kakao Brain
#
#
# All contributions from Caffe:
# Copyright(c) 2013, 2014, 2015, the respective contributors
# All rights reserved.
#
#
# All other contributions:
# Copyright(c) 2015, 2016 the respective contributors
# All rights reserved.
#
#
# Caffe2 uses a copyright model similar to Caffe: each contributor holds
# copyright over their contributions to Caffe2. The project versioning records
# all such contribution and copyright details. If a contributor wants to further
# mark their specific copyright on a particular contribution, they should
# indicate their copyright solely in the commit message of the change when it is
# committed.
#
#
# All rights reserved.
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
#
# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
# and IDIAP Research Institute nor the names of its contributors may be
# used to endorse or promote products derived from this software without
# specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
......@@ -73,7 +73,7 @@
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
#
# Find the nccl libraries
#
# The following variables are optionally searched for defaults
......
......@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
add_subdirectory(cpp)
\ No newline at end of file
add_subdirectory(cpp)
......@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
add_subdirectory(llama)
\ No newline at end of file
add_subdirectory(llama)
......@@ -2,5 +2,5 @@
add_executable(llama_triton_example llama_triton_example.cc)
target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart
LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils
nvtx_utils word_list glog)
\ No newline at end of file
LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils
nvtx_utils word_list glog)
# Copyright (c) OpenMMLab. All rights reserved.
import subprocess
import fire
......@@ -12,7 +13,8 @@ def main(head_num: int = 32,
max_batch_size: int = 64):
for bsz in range(1, max_batch_size + 1):
subprocess.call(
f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size} {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}',
f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size}'
f' {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}',
shell=True)
......
......@@ -78,5 +78,3 @@ rotary_embedding=128
start_id=1
end_id=2
inter_size=22016
......@@ -5,4 +5,4 @@
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,36589,3467,7849,299,7032,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,44976,39798,6828,3467,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,2795,977,9193,299,405,537,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,45691,45926,45513,46641,47641,46285,6456,46323,13,44975,45004,11130,32843,45004,35597
\ No newline at end of file
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,45691,45926,45513,46641,47641,46285,6456,46323,13,44975,45004,11130,32843,45004,35597
from typing import List
import fire
......
......@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
add_subdirectory(fastertransformer)
\ No newline at end of file
add_subdirectory(fastertransformer)
......@@ -19,4 +19,4 @@ add_subdirectory(models)
if(BUILD_PYT)
add_subdirectory(th_op)
endif()
add_subdirectory(triton_backend)
\ No newline at end of file
add_subdirectory(triton_backend)
......@@ -467,4 +467,4 @@ template void invokeQuantizeMatrixRebuildPadding<half, __nv_fp8_e4m3, QUANTIZE_M
#endif
} // namespace fastertransformer
\ No newline at end of file
} // namespace fastertransformer
......@@ -395,4 +395,4 @@ template void invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(AllReduceParams<_
cudaStream_t stream);
#endif
template void invokeOneOrTwoShotAllReduceKernel<uint32_t>(AllReduceParams<uint32_t>& param, cudaStream_t stream);
} // namespace fastertransformer
\ No newline at end of file
} // namespace fastertransformer
......@@ -60,4 +60,4 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams<T>& param, cudaStream_t s
void kernelLaunchConfig(int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo);
} // namespace fastertransformer
\ No newline at end of file
} // namespace fastertransformer
......@@ -116,8 +116,8 @@ struct Multihead_attention_params_base {
const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr;
int int8_mode = 0;
float attention_k_scale = 0.f;
float attention_v_scale = 0.f;
float attention_k_scale = 0.f;
float attention_v_scale = 0.f;
};
template<typename T>
......
......@@ -17,10 +17,10 @@
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/models/llama/llama_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include "src/fastertransformer/models/llama/llama_utils.h"
#include <assert.h>
#include <float.h>
#include <type_traits>
......@@ -81,7 +81,8 @@ namespace mmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Dh>
struct Qk_vec_m_ {};
struct Qk_vec_m_ {
};
template<>
struct Qk_vec_m_<float, 32> {
......@@ -181,7 +182,8 @@ struct Qk_vec_k_<__nv_fp8_e4m3, 256> {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY>
struct K_vec_m_ {};
struct K_vec_m_ {
};
template<>
struct K_vec_m_<float, 4> {
......@@ -262,7 +264,8 @@ struct K_vec_k_<__nv_fp8_e4m3, 1> {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int V_VEC_SIZE>
struct V_vec_m_ {};
struct V_vec_m_ {
};
template<>
struct V_vec_m_<float, 1> {
......@@ -342,7 +345,8 @@ struct V_vec_k_<__nv_fp8_e4m3, 16> {
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template<typename T>
struct Qk_vec_acum_fp32_ {};
struct Qk_vec_acum_fp32_ {
};
template<>
struct Qk_vec_acum_fp32_<float> {
......@@ -424,7 +428,8 @@ struct Qk_vec_acum_fp32_<fp8_4_t> {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct K_vec_acum_fp32_ {};
struct K_vec_acum_fp32_ {
};
template<>
struct K_vec_acum_fp32_<float> {
......@@ -486,7 +491,8 @@ struct K_vec_acum_fp32_<fp8_4_t> {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template<typename T>
struct V_vec_acum_fp32_ {};
struct V_vec_acum_fp32_ {
};
template<>
struct V_vec_acum_fp32_<float> {
......@@ -1455,14 +1461,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B
+ tlength_circ * QK_ELTS_IN_16B + ci;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
Packed_Int8_t k_int8 = quant(k, k_scale);
int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache);
int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache);
*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
} else {
}
else {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
}
}
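For intuition, the int8 KV-cache branch above stores `quant(k, k_scale)` and later reads the cache back through `dequant(..., k_scale)`, using the `attention_k_scale` / `attention_v_scale` parameters added earlier in this diff. The calibration of those scales is not shown in these hunks; the sketch below assumes a simple symmetric per-tensor scheme (scale = max|x| / 127) purely for illustration, not as the kernel's actual method:

```python
import numpy as np

def quantize_int8(x: np.ndarray, scale: float) -> np.ndarray:
    """Symmetric per-tensor quantization, float -> int8 (illustrative only)."""
    return np.clip(np.round(x / scale), -128, 127).astype(np.int8)

def dequantize_int8(q: np.ndarray, scale: float) -> np.ndarray:
    """Inverse mapping, int8 -> float."""
    return q.astype(np.float32) * scale

k = np.random.randn(128).astype(np.float32)   # stand-in for one K vector
k_scale = float(np.abs(k).max()) / 127.0      # assumed calibration, not taken from the source
k_cache = quantize_int8(k, k_scale)           # analogous to what the int8 cache stores
k_restored = dequantize_int8(k_cache, k_scale)
print(np.max(np.abs(k - k_restored)))         # rounding error stays within ~k_scale / 2
```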
......@@ -1483,11 +1490,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
} else {
}
else {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
vec_conversion<Qk_vec_m, Qk_vec_k>(k);
}
}
}
}
......@@ -1565,29 +1572,29 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
// The base pointer for the key in the cache buffer.
T* k_cache_batch = nullptr;
T* k_cache_batch = nullptr;
int8_t* k_cache_batch_int8 = nullptr;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
// convert k_cache_per_sample to int8
if (params.k_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki;
} else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache);
}
else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache);
k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki];
}
} else {
T* k_cache =
params.k_cache_per_sample ?
(params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki) :
&params.k_cache[bhi * params.memory_max_len * Dh + ki];
}
else {
T* k_cache = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+ hi * params.memory_max_len * Dh + ki) :
&params.k_cache[bhi * params.memory_max_len * Dh + ki];
// Base pointer for the beam's batch, before offsetting with indirection buffer
// T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
k_cache_batch = k_cache;
}
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
......@@ -1626,12 +1633,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
using Packed_Int8_t = typename packed_type<int8_t, num_elems<K_vec_m>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<K_vec_m>::value>::type;
Packed_Int8_t k_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&k_cache_batch_int8[beam_offset + jj * QK_ELTS_IN_16B]);
Packed_Int8_t k_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(
&k_cache_batch_int8[beam_offset + jj * QK_ELTS_IN_16B]);
Packed_Float_t k_vec_m_float = dequant(k_vec_m_int8, k_scale);
k[ii] = vec_conversion<K_vec_k, Packed_Float_t>(k_vec_m_float);
} else {
k[ii] = vec_conversion<K_vec_k, K_vec_m>((*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
}
else {
k[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
}
}
}
......@@ -1747,28 +1757,29 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
// The base pointer for the value in the cache buffer.
T* v_cache = nullptr;
T* v_cache = nullptr;
T* v_cache_batch = nullptr;
int8_t* v_cache_int8 = nullptr;
int8_t* v_cache_int8 = nullptr;
int8_t* v_cache_batch_int8 = nullptr;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
if (params.v_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
} else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache);
}
else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache);
v_cache_int8 = &ptr[bhi * params.memory_max_len * Dh + vi];
}
v_cache_batch_int8 = v_cache_int8;
} else {
}
else {
v_cache =
params.v_cache_per_sample ?
(params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi) :
&params.v_cache[bhi * params.memory_max_len * Dh + vi];
v_cache = params.v_cache_per_sample ? (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+ hi * params.memory_max_len * Dh + vi) :
&params.v_cache[bhi * params.memory_max_len * Dh + vi];
// Base pointer for the beam's batch, before offsetting with indirection buffer
// T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
v_cache_batch = v_cache;
......@@ -1822,14 +1833,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache.
V_vec_k v;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
Packed_Int8_t v_vec_m_int8 =
*reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
} else {
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
}
else {
v = vec_conversion<V_vec_k, V_vec_m>(
*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
}
// Load the logits from shared memory.
......@@ -1867,14 +1881,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache.
V_vec_k v;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
Packed_Int8_t v_vec_m_int8 =
*reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
} else {
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
}
else {
v = vec_conversion<V_vec_k, V_vec_m>(
*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
}
// Load the logits from shared memory.
......@@ -1910,7 +1927,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Trigger the loads from the V buffer.
const auto v_offset = qkv_base_offset + vi;
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));
// Trigger the loads from the V bias buffer.
// V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
......@@ -1925,7 +1942,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
Packed_Int8_t v_int8 = quant(v, v_scale);
*reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8;
} else {
}
else {
*reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
}
}
......@@ -1994,7 +2012,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]),
mul<V_vec_acum, float, V_vec_acum>(result_scale, out));
#endif // FP8_MHA
} else {
}
else {
convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]), out);
}
#else // MMHA_USE_FP32_ACUM_FOR_OUT
......