Unverified Commit fe46dac2 authored by AllentDan, committed by GitHub

Add lint action (#32)

* temp

* fix lint

* csrc->src

* remove clang-format

* skip .rst

* skip doc

* clang-format

version

version

* mat_B
parent e8ab4ba3
# Copyright (c) MegFlow. All rights reserved.
# /bin/python3
import argparse
import os
import re


def make_parser():
    parser = argparse.ArgumentParser('Doc link checker')
    parser.add_argument('--http',
                        default=False,
                        type=bool,
                        help='check http or not ')
    parser.add_argument('--target',
                        default='./docs',
                        type=str,
                        help='the directory or file to check')
    return parser


pattern = re.compile(r'\[.*?\]\(.*?\)')
def analyze_doc(home, path):
    print('analyze {}'.format(path))
    problem_list = []
    code_block = 0
    with open(path) as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            if line.startswith('```'):
                code_block = 1 - code_block

            if code_block > 0:
                continue

            if '[' in line and ']' in line and '(' in line and ')' in line:
                all = pattern.findall(line)
                for item in all:
                    # skip ![]()
                    if item.find('[') == item.find(']') - 1:
                        continue

                    # process the case [text()]()
                    offset = item.find('](')
                    if offset == -1:
                        continue
                    item = item[offset:]
                    start = item.find('(')
                    end = item.find(')')
                    ref = item[start + 1:end]

                    if ref.startswith('http') or ref.startswith('#'):
                        continue
                    if '.md#' in ref:
                        # strip the anchor so only the file part is checked
                        ref = ref[:ref.find('#')]
                    fullpath = os.path.join(home, ref)
                    if not os.path.exists(fullpath):
                        problem_list.append(ref)
                    else:
                        continue
    if len(problem_list) > 0:
        print(f'{path}:')
        for item in problem_list:
            print(f'\t {item}')
        print('\n')
        raise Exception('found link error')


def traverse(target):
    if os.path.isfile(target):
        analyze_doc(os.path.dirname(target), target)
        return
    for home, dirs, files in os.walk(target):
        for filename in files:
            if filename.endswith('.md'):
                path = os.path.join(home, filename)
                if os.path.islink(path) is False:
                    analyze_doc(home, path)


if __name__ == '__main__':
    args = make_parser().parse_args()
    traverse(args.target)
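A quick, self-contained sanity check of the link pattern used above (the sample lines are made up, not taken from the repo): the regex extracts inline Markdown links, and the checks in analyze_doc then skip http(s) links, in-page anchors and empty-text image links, so only relative file references are resolved against the Markdown file's directory.

# Illustration only -- not part of the committed script.
import re

pattern = re.compile(r'\[.*?\]\(.*?\)')  # same pattern as in the checker above
samples = [
    '[guide](docs/install.md)',      # relative link -> file existence is checked
    '[site](https://example.com)',   # http(s) link -> skipped
    '[top](#section)',               # in-page anchor -> skipped
    '![](images/logo.png)',          # empty-text image link -> skipped ("[" directly before "]")
]
for line in samples:
    print(pattern.findall(line))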
name: lint

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.7
        uses: actions/setup-python@v2
        with:
          python-version: 3.7
      - name: Install pre-commit hook
        run: |
          python -m pip install pre-commit
          pre-commit install
      - name: Linting
        run: pre-commit run --all-files
      - name: Format c/cuda codes with clang-format
        uses: DoozyX/clang-format-lint-action@v0.14
        with:
          source: src
          extensions: h,c,cpp,hpp,cu,cuh
          clangFormatVersion: 14
          style: file
      - name: Check markdown link
        uses: gaurav-nelson/github-action-markdown-link-check@v1
        with:
          use-quiet-mode: 'yes'
          use-verbose-mode: 'yes'
          # check-modified-files-only: 'yes'
          config-file: '.github/md-link-config.json'
          file-path: './README.md, ./LICENSE, ./README_zh-CN.md'
      - name: Check doc link
        run: |
          python .github/scripts/doc_link_checker.py --target README_zh-CN.md
          python .github/scripts/doc_link_checker.py --target README.md
      - name: Check docstring coverage
        run: |
          python -m pip install interrogate
          interrogate -v --ignore-init-method --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 80 lmdeploy
      - name: Check pylint score
        run: |
          python -m pip install pylint
          pylint lmdeploy
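The interrogate step fails the workflow when docstring coverage of the lmdeploy package drops below 80%. A toy illustration (hypothetical module, not lmdeploy code) of what the flags above exclude: module docstrings and __init__ methods are ignored, so only the class and public method below need docstrings to count as covered.

# toy_engine.py -- hypothetical example to illustrate the interrogate flags above.
class Engine:
    """Documented class: counts toward the reported coverage."""

    def __init__(self, model: str):
        # No docstring required here because of --ignore-init-method.
        self.model = model

    def infer(self, prompt: str) -> str:
        """Documented public method: counts toward the reported coverage."""
        return f'{self.model}: {prompt}'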
# Copyright (c) OpenMMLab. All rights reserved.
import subprocess
import fire
......@@ -12,7 +13,8 @@ def main(head_num: int = 32,
         max_batch_size: int = 64):
    for bsz in range(1, max_batch_size + 1):
        subprocess.call(
            f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size} {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}',
            f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size}'
            f' {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}',
            shell=True)
......
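The change shown in this hunk is cosmetic: the long command is split into two adjacent f-string literals, which Python concatenates into the same string. A minimal check (the numeric values are placeholders, since the real defaults are outside this hunk):

# Placeholder values, only to show that the split literals build the same command.
bsz, head_num, size_per_head, inter_size = 1, 32, 128, 11008
vocab_size, tensor_para_size = 32000, 1
one_line = f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size} {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}'
two_lines = (f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size}'
             f' {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}')
assert one_line == two_lines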
......@@ -78,5 +78,3 @@ rotary_embedding=128
start_id=1
end_id=2
inter_size=22016
from typing import List
import fire
......
......@@ -17,10 +17,10 @@
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/models/llama/llama_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include "src/fastertransformer/models/llama/llama_utils.h"
#include <assert.h>
#include <float.h>
#include <type_traits>
......@@ -81,7 +81,8 @@ namespace mmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Dh>
struct Qk_vec_m_ {};
struct Qk_vec_m_ {
};
template<>
struct Qk_vec_m_<float, 32> {
......@@ -181,7 +182,8 @@ struct Qk_vec_k_<__nv_fp8_e4m3, 256> {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY>
struct K_vec_m_ {};
struct K_vec_m_ {
};
template<>
struct K_vec_m_<float, 4> {
......@@ -262,7 +264,8 @@ struct K_vec_k_<__nv_fp8_e4m3, 1> {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int V_VEC_SIZE>
struct V_vec_m_ {};
struct V_vec_m_ {
};
template<>
struct V_vec_m_<float, 1> {
......@@ -342,7 +345,8 @@ struct V_vec_k_<__nv_fp8_e4m3, 16> {
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template<typename T>
struct Qk_vec_acum_fp32_ {};
struct Qk_vec_acum_fp32_ {
};
template<>
struct Qk_vec_acum_fp32_<float> {
......@@ -424,7 +428,8 @@ struct Qk_vec_acum_fp32_<fp8_4_t> {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct K_vec_acum_fp32_ {};
struct K_vec_acum_fp32_ {
};
template<>
struct K_vec_acum_fp32_<float> {
......@@ -486,7 +491,8 @@ struct K_vec_acum_fp32_<fp8_4_t> {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template<typename T>
struct V_vec_acum_fp32_ {};
struct V_vec_acum_fp32_ {
};
template<>
struct V_vec_acum_fp32_<float> {
......@@ -1462,7 +1468,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
            int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache);
            *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
        } else {
        }
        else {
            *reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
        }
    }
......@@ -1483,11 +1490,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
                *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
            } else {
            }
            else {
                *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
                    vec_conversion<Qk_vec_m, Qk_vec_k>(k);
            }
        }
    }
}
......@@ -1573,21 +1580,21 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
        if (params.k_cache_per_sample) {
            int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
            k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki;
        } else {
        }
        else {
            int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache);
            k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki];
        }
    } else {
        T* k_cache =
            params.k_cache_per_sample ?
                (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki) :
    }
    else {
        T* k_cache = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
                                                  + hi * params.memory_max_len * Dh + ki) :
                                                 &params.k_cache[bhi * params.memory_max_len * Dh + ki];
        // Base pointer for the beam's batch, before offsetting with indirection buffer
        // T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
        k_cache_batch = k_cache;
    }

    // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
    // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
    int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
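For readability, the index arithmetic from this hunk re-written in Python. The names are taken from the diff; their exact semantics (bhi as a flattened batch/head index, memory_max_len as the cache capacity in timesteps, Dh as the head size) are assumptions, not something this hunk states.

# Sketch of the two K-cache addressing modes shown above (assumed layout).
def k_cache_offset(bhi: int, memory_max_len: int, Dh: int, ki: int) -> int:
    # contiguous cache: one [memory_max_len, Dh] slab per flattened (batch, head) index
    return bhi * memory_max_len * Dh + ki


def k_cache_per_sample_offset(kv_cache_per_sample_offset: int, hi: int,
                              memory_max_len: int, Dh: int, ki: int) -> int:
    # per-sample cache: offset relative to that sample's own base pointer
    return kv_cache_per_sample_offset + hi * memory_max_len * Dh + ki


print(k_cache_offset(bhi=3, memory_max_len=2048, Dh=128, ki=0))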
......@@ -1626,12 +1633,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                using Packed_Int8_t = typename packed_type<int8_t, num_elems<K_vec_m>::value>::type;
                using Packed_Float_t = typename packed_type<float, num_elems<K_vec_m>::value>::type;
                Packed_Int8_t k_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&k_cache_batch_int8[beam_offset + jj * QK_ELTS_IN_16B]);
                Packed_Int8_t k_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(
                    &k_cache_batch_int8[beam_offset + jj * QK_ELTS_IN_16B]);
                Packed_Float_t k_vec_m_float = dequant(k_vec_m_int8, k_scale);
                k[ii] = vec_conversion<K_vec_k, Packed_Float_t>(k_vec_m_float);
            } else {
                k[ii] = vec_conversion<K_vec_k, K_vec_m>((*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
            }
            else {
                k[ii] = vec_conversion<K_vec_k, K_vec_m>(
                    (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
            }
        }
    }
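The quant/dequant helpers and the exact meaning of k_scale/v_scale are defined elsewhere, not in this diff. As a rough sketch of the int8 KV-cache round trip the kernel performs (assuming a simple symmetric per-tensor scale, which may not match the real implementation):

# Assumed symmetric int8 quantization -- illustration only.
import numpy as np

def quant(v: np.ndarray, scale: float) -> np.ndarray:
    return np.clip(np.round(v / scale), -128, 127).astype(np.int8)

def dequant(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale

k = np.random.uniform(-1.0, 1.0, size=8).astype(np.float32)
k_scale = float(np.abs(k).max()) / 127.0
assert np.allclose(dequant(quant(k, k_scale), k_scale), k, atol=k_scale)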
......@@ -1757,17 +1767,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
        if (params.v_cache_per_sample) {
            int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
            v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
        } else {
        }
        else {
            int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache);
            v_cache_int8 = &ptr[bhi * params.memory_max_len * Dh + vi];
        }
        v_cache_batch_int8 = v_cache_int8;
    } else {
    }
    else {
        v_cache =
            params.v_cache_per_sample ?
                (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi) :
        v_cache = params.v_cache_per_sample ? (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset
                                               + hi * params.memory_max_len * Dh + vi) :
                                              &params.v_cache[bhi * params.memory_max_len * Dh + vi];
        // Base pointer for the beam's batch, before offsetting with indirection buffer
        // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
......@@ -1824,12 +1835,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
        V_vec_k v;
        if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
            Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
            Packed_Int8_t v_vec_m_int8 =
                *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
            Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
            v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
        } else {
            v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
        }
        else {
            v = vec_conversion<V_vec_k, V_vec_m>(
                *reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
        }

        // Load the logits from shared memory.
......@@ -1869,12 +1883,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
        V_vec_k v;
        if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
            Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
            Packed_Int8_t v_vec_m_int8 =
                *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
            Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
            v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
        } else {
            v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
        }
        else {
            v = vec_conversion<V_vec_k, V_vec_m>(
                *reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
        }

        // Load the logits from shared memory.
......@@ -1925,7 +1942,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
            using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
            Packed_Int8_t v_int8 = quant(v, v_scale);
            *reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8;
        } else {
        }
        else {
            *reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
        }
    }
......@@ -1994,7 +2012,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
        convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]),
                           mul<V_vec_acum, float, V_vec_acum>(result_scale, out));
#endif // FP8_MHA
    } else {
    }
    else {
        convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]), out);
    }
#else // MMHA_USE_FP32_ACUM_FOR_OUT
......