Unverified Commit 2700abb3 authored by Li Zhang, committed by GitHub

Support attention bias (#14)

* support attention bias

* fix conflict
parent ee962784
@@ -81,6 +81,8 @@ def export(model_name: str,
             param = param.half()
         param.contiguous().numpy().tofile(osp.join(out_dir, name))
 
+    attn_bias = False
+
     # reverse the splitting axes since the weights are transposed above
     for param_name, param_data in model_params.items():
         if param_name == 'tok_embeddings.weight':
@@ -88,13 +90,18 @@ def export(model_name: str,
             head_num = dim // size_per_head
         split_dim = None
         key, ext = param_name.split('.')[-2:]
+        if key == 'w_qkv' and ext == 'bias':
+            attn_bias = True
         copy = False
         if key in ['w1', 'w3', 'w_qkv']:
-            split_dim = -1
+            if ext in ['bias']:
+                copy = True
+            else:
+                split_dim = -1
             if key == 'w1':
                 inter_size = param_data.shape[-1]
         elif key in ['w2', 'wo']:
-            if ext in ['scales', 'zeros']:
+            if ext in ['scales', 'zeros', 'bias']:
                 copy = True
             else:
                 split_dim = 0
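
A minimal sketch of what the copy/split flags set in the hunk above amount to on the export side. The helper name `shard_for_export` and the `tp` argument are illustrative, not part of the script: bias vectors of the column-parallel tensors are copied to every tensor-parallel rank, while the (already transposed) weights are split along their last axis.

    import torch

    def shard_for_export(key: str, ext: str, tensor: torch.Tensor, tp: int):
        """Mirror of the copy/split flags assigned in the hunk above."""
        if key in ('w1', 'w3', 'w_qkv'):
            if ext == 'bias':                                  # copy = True
                return [tensor] * tp
            return list(torch.chunk(tensor, tp, dim=-1))       # split_dim = -1
        if key in ('w2', 'wo'):
            if ext in ('scales', 'zeros', 'bias'):             # copy = True
                return [tensor] * tp
            return list(torch.chunk(tensor, tp, dim=0))        # split_dim = 0
        raise KeyError(key)
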
@@ -129,6 +136,7 @@ def export(model_name: str,
         rotary_embedding=size_per_head,
         inter_size=inter_size,
         norm_eps=norm_eps,
+        attn_bias=attn_bias,
         start_id=bos_id,
         end_id=eos_id,
         weight_type='fp16',
@@ -189,20 +197,28 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
     for i, ckpt_path in enumerate(checkpoints):
         ckpt = torch.load(ckpt_path, map_location='cpu')
         for param_name, param_data in ckpt.items():
-            key = param_name.split('.')[-2]
+            key, ext = param_name.split('.')[-2:]
             # column-parallel
             if key in ['w1', 'w3', 'wq', 'wk', 'wv', 'output']:
                 size = param_data.size(0)
-                param = get_param(
-                    param_name,
-                    [size * n_ckpt, param_data.size(1)])
-                param.data[size * i:size * (i + 1), :] = param_data
+                if ext == 'weight':
+                    param = get_param(
+                        param_name, [size * n_ckpt, param_data.size(1)])
+                    param.data[size * i:size * (i + 1), :] = param_data
+                else:  # bias
+                    param = get_param(param_name, [size * n_ckpt])
+                    param.data[size * i:size * (i + 1)] = param_data
             # row-parallel
             elif key in ['w2', 'wo', 'tok_embeddings']:
                 size = param_data.size(-1)
-                param = get_param(param_name,
-                                  [param_data.size(0), size * n_ckpt])
-                param.data[:, size * i:size * (i + 1)] = param_data
+                if ext == 'weight':
+                    param = get_param(param_name,
+                                      [param_data.size(0), size * n_ckpt])
+                    param.data[:, size * i:size * (i + 1)] = param_data
+                else:  # bias
+                    param = get_param(param_name, [size])
+                    param.data = param_data
             elif i == 0:
                 param = get_param(param_name, param_data.size())
                 param.data = param_data
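
A hedged reading of the bias handling just added, written as an illustrative helper that is not part of the script: for column-parallel projections each checkpoint shard holds a slice of the bias, so the loader concatenates them, while for row-parallel projections every shard stores the full bias and only one copy is kept.

    import torch

    def merge_bias_shards(key: str, shards: list) -> torch.Tensor:
        """Mirror of the else-branches above for `ext == 'bias'`."""
        if key in ('w1', 'w3', 'wq', 'wk', 'wv', 'output'):
            # column-parallel: concatenate the per-checkpoint [size] slices
            return torch.cat(shards, dim=0)
        if key in ('w2', 'wo', 'tok_embeddings'):
            # row-parallel: the bias is replicated in every checkpoint
            return shards[0]
        raise KeyError(key)
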
@@ -216,15 +232,18 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
             param.data = param.data.t()
 
     # concat qkv projection
-    for i in range(1000):
-        _qkv = [f'layers.{i}.attention.{k}.weight' for k in ['wq', 'wk', 'wv']]
-        try:
-            qkv = tuple(map(model_params.pop, _qkv))
-        except KeyError:
-            break
-        qkv = torch.stack(qkv, dim=1)
-        model_params[f'layers.{i}.attention.w_qkv.weight'] = qkv
-        print(qkv.shape, qkv.dtype)
+    for t in ['weight', 'bias']:
+        for i in range(1000):
+            _qkv = [f'layers.{i}.attention.{k}.{t}' for k in [
+                'wq', 'wk', 'wv']]
+            try:
+                qkv = tuple(map(model_params.pop, _qkv))
+            except KeyError:
+                break
+            # concat by output_dims
+            qkv = torch.stack(qkv, dim=qkv[0].dim() - 1)
+            print(f'layers.{i}.attention.w_qkv.{t}', qkv.shape)
+            model_params[f'layers.{i}.attention.w_qkv.{t}'] = qkv
 
     assert num_layer == i, f'miss matched layers: {num_layer} vs {i}'
......
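
A shape-level illustration of the `torch.stack` packing above, assuming the projections were already transposed to [input, output] by the earlier `param.data.t()` step (the hidden size below is made up). Stacking at `dim() - 1` inserts the q/k/v axis just before the output dimension, which keeps weight and bias layouts compatible.

    import torch

    hidden = 4096  # illustrative size only
    wq, wk, wv = (torch.empty(hidden, hidden) for _ in range(3))
    bq, bk, bv = (torch.empty(hidden) for _ in range(3))

    w_qkv = torch.stack((wq, wk, wv), dim=wq.dim() - 1)  # -> [hidden, 3, hidden]
    b_qkv = torch.stack((bq, bk, bv), dim=bq.dim() - 1)  # -> [3, hidden]
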
@@ -16,7 +16,8 @@
  * limitations under the License.
  */
 
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc
 
 #include "src/fastertransformer/models/llama/LlamaContextAttentionLayer.h"
 #include "src/fastertransformer/kernels/bert_preprocess_kernels.h"

@@ -157,7 +158,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                                    v_buf_2_,
                                                    PrefixPromptBatchWeightsParam<T>{},
                                                    qkv_buf_,
-                                                   (const T*)nullptr,  // qkv_bias
+                                                   weights->qkv.bias,
                                                    padding_offset,  // padding_offset,
                                                    history_length,  // used for applying rotary embedding
                                                    batch_size,
......
@@ -15,7 +15,8 @@
  * limitations under the License.
  */
 
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.cc
 
 #include "src/fastertransformer/models/llama/LlamaContextDecoder.h"
 #include "src/fastertransformer/kernels/bert_preprocess_kernels.h"

@@ -243,8 +244,9 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
         /// self-attention
         forwardSelfAttn(sess, input_tensors, layer, false);
 
-        invokeFusedAddResidualRMSNorm(decoder_input_output,
+        invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
                                       attn_ffn_io_,
+                                      decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
                                       decoder_layer_weights->at(layer)->ffn_norm_weights,
                                       rmsnorm_eps_,
                                       sess.token_num,

@@ -260,8 +262,9 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
         auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
                                                      input_tensors->at("output_norm_weight").getPtr<T>();
 
-        invokeFusedAddResidualRMSNorm(decoder_input_output,  //
+        invokeFusedAddBiasResidualRMSNorm(decoder_input_output,  //
                                       attn_ffn_io_,
+                                      decoder_layer_weights->at(layer)->ffn_weights.output.bias,
                                       scale_weight,
                                       rmsnorm_eps_,
                                       sess.token_num,
......
@@ -16,7 +16,8 @@
  * limitations under the License.
  */
 
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc
 
 #include "src/fastertransformer/models/llama/LlamaDecoder.h"
 #include "src/fastertransformer/models/llama/llama_decoder_kernels.h"

@@ -205,8 +206,9 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
         // output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_)
         forwardSelfAttn(sess, decoder_output, input_tensors, layer);
 
-        invokeFusedAddResidualRMSNorm(decoder_input,
+        invokeFusedAddBiasResidualRMSNorm(decoder_input,
                                       decoder_output,
+                                      decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
                                       decoder_layer_weights->at(layer)->ffn_norm_weights,
                                       rmsnorm_eps_,
                                       sess.batch_size,

@@ -219,8 +221,9 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
         auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
                                                      input_tensors->at("output_norm_weight").getPtr<T>();
 
-        invokeFusedAddResidualRMSNorm(decoder_input,  //
+        invokeFusedAddBiasResidualRMSNorm(decoder_input,  //
                                       decoder_output,
+                                      decoder_layer_weights->at(layer)->ffn_weights.output.bias,
                                       scale_weight,
                                       rmsnorm_eps_,
                                       sess.batch_size,
......
@@ -15,8 +15,8 @@
  * limitations under the License.
  */
 
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc
 
 #include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h"
 #include "src/fastertransformer/utils/logger.h"

@@ -25,11 +25,16 @@
 namespace fastertransformer {
 
 template<typename T>
-LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(
-    size_t hidden_units, size_t inter_size, WeightType weight_type, size_t tensor_para_size, size_t tensor_para_rank):
+LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t hidden_units,
+                                                    size_t inter_size,
+                                                    WeightType weight_type,
+                                                    bool attn_bias,
+                                                    size_t tensor_para_size,
+                                                    size_t tensor_para_rank):
     hidden_units_(hidden_units),
     inter_size_(inter_size),
     weight_type_(weight_type),
+    attn_bias_(attn_bias),
    tensor_para_size_(tensor_para_size),
     tensor_para_rank_(tensor_para_rank)
 {

@@ -117,8 +122,8 @@ void LlamaDecoderLayerWeight<T>::mallocWeights()
     deviceMalloc((T**)&self_attn_norm_weights, hidden_units_);
     deviceMalloc((T**)&ffn_norm_weights, hidden_units_);
 
-    fastertransformer::mallocWeights(self_attn_weights.qkv, false);
-    fastertransformer::mallocWeights(self_attn_weights.output, false);
+    fastertransformer::mallocWeights(self_attn_weights.qkv, attn_bias_);
+    fastertransformer::mallocWeights(self_attn_weights.output, attn_bias_);
 
     fastertransformer::mallocWeights(ffn_weights.gating, false);
     fastertransformer::mallocWeights(ffn_weights.intermediate, false);
......
@@ -15,7 +15,8 @@
  * limitations under the License.
  */
 
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h
 
 #pragma once

@@ -27,8 +28,12 @@ template<typename T>
 struct LlamaDecoderLayerWeight {
 public:
     LlamaDecoderLayerWeight() = delete;
-    LlamaDecoderLayerWeight(
-        size_t hidden_units, size_t inter_size, WeightType weight_type, size_t tensor_para_size, size_t tensor_para_rank);
+    LlamaDecoderLayerWeight(size_t hidden_units,
+                            size_t inter_size,
+                            WeightType weight_type,
+                            bool attn_bias,
+                            size_t tensor_para_size,
+                            size_t tensor_para_rank);
     ~LlamaDecoderLayerWeight();
 
     LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete;
     LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight& other) = delete;

@@ -45,6 +50,7 @@ private:
     size_t inter_size_;
     WeightType weight_type_;
     size_t bit_size_;
+    bool attn_bias_;
     size_t tensor_para_size_;
     size_t tensor_para_rank_;
     bool is_maintain_buffer_ = false;
......
@@ -15,8 +15,8 @@
  * limitations under the License.
  */
 
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
 
 #include "src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h"
 #include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"

@@ -237,7 +237,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
     fusedQKV_masked_attention_dispatch<T>(
         qkv_buf_,
-        nullptr,  // query_weight.bias,
+        weights->qkv.bias,  // query_weight.bias,
         nullptr,  // relative_attention_bias,
         nullptr,
         nullptr,
......
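
For orientation, this is what feeding `weights->qkv.bias` into the fused attention kernels corresponds to numerically, written in plain PyTorch under the assumed [hidden, 3, hidden] weight and [3, hidden] bias layouts from the export step; it is a reference sketch, not the CUDA code path.

    import torch

    def qkv_projection(x, w_qkv, b_qkv):
        # x: [tokens, hidden]; w_qkv: [hidden, 3, hidden]; b_qkv: [3, hidden]
        qkv = torch.einsum('ti,ico->tco', x, w_qkv) + b_qkv
        q, k, v = qkv.unbind(dim=1)  # each [tokens, hidden]
        return q, k, v
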
@@ -15,7 +15,8 @@
  * limitations under the License.
  */
 
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc
 
 #include "src/fastertransformer/models/llama/LlamaWeight.h"

@@ -27,6 +28,7 @@ LlamaWeight<T>::LlamaWeight(size_t hidden_units,
                             size_t vocab_size,
                             size_t num_layer,
                             WeightType weight_type,
+                            bool attn_bias,
                             size_t tensor_para_size,
                             size_t tensor_para_rank,
                             int prefix_cache_len):

@@ -42,7 +44,7 @@ LlamaWeight<T>::LlamaWeight(size_t hidden_units,
     decoder_layer_weights.reserve(num_layer_);
     for (unsigned l = 0; l < num_layer_; ++l) {
         decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(
-            hidden_units_, inter_size_, weight_type_, tensor_para_size_, tensor_para_rank_));
+            hidden_units_, inter_size_, weight_type_, attn_bias, tensor_para_size_, tensor_para_rank_));
     }
 
     mallocWeights();
......
@@ -15,7 +15,8 @@
  * limitations under the License.
  */
 
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h
 
 #pragma once

@@ -32,6 +33,7 @@ struct LlamaWeight {
                 size_t vocab_size,
                 size_t num_layer,
                 WeightType weight_type,
+                bool attn_bias,
                 size_t tensor_para_size,
                 size_t tensor_para_rank,
                 int prefix_cache_len);
......
@@ -16,13 +16,13 @@ struct res_norm_ops_t {};
 template<typename T>
 struct res_norm_t {
     res_norm_ops_t<T> f;
-    __device__ uint4 addvec(const uint4& a, const uint4& b, float& accum) const
+    __device__ uint4 addvec(const uint4& a, const uint4& b, const uint4& bias, float& accum) const
     {
         uint4 c;
-        c.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), accum));
-        c.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), accum));
-        c.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), accum));
-        c.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), accum));
+        c.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), f.cast(bias.x), accum));
+        c.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), f.cast(bias.y), accum));
+        c.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), f.cast(bias.z), accum));
+        c.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), f.cast(bias.w), accum));
         return c;
     }
     __device__ uint4 normvec(const uint4& u, const uint4& s, float factor) const
@@ -47,9 +47,9 @@ struct res_norm_ops_t<half> {
         auto y = __float22half2_rn(x);
         return reinterpret_cast<uint&>(y);
     }
-    __device__ float2 add(const float2& a, const float2& b, float& accum) const
+    __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
     {
-        float2 c{a.x + b.x, a.y + b.y};
+        float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y};
         accum += c.x * c.x + c.y * c.y;
         return c;
     }
@@ -69,9 +69,9 @@ struct res_norm_ops_t<float> {
     {
         return reinterpret_cast<const uint&>(x);
     }
-    __device__ float add(const float& a, const float& b, float& accum) const
+    __device__ float add(const float& a, const float& b, const float& bias, float& accum) const
     {
-        float c = a + b;
+        float c = a + b + bias;
         accum += c * c;
         return c;
     }
@@ -100,17 +100,23 @@ __device__ T blockReduceSum(const cg::thread_block& block, T value)
 }
 
 template<typename T>
-__global__ void fusedAddResidualNorm(
-    T* __restrict__ r_data, T* __restrict__ x_data, const T* __restrict__ scale, float eps, int batch_size, int n_dims)
+__global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
+                                         T* __restrict__ x_data,
+                                         const T* __restrict__ bias,
+                                         const T* __restrict__ scale,
+                                         float eps,
+                                         int batch_size,
+                                         int n_dims)
 {
     auto block = cg::this_thread_block();
     auto grid = cg::this_grid();
 
     constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
 
-    const auto b = grid.block_rank();
-    uint4* __restrict__ r_ptr = reinterpret_cast<uint4*>(r_data + b * n_dims);
-    uint4* __restrict__ x_ptr = reinterpret_cast<uint4*>(x_data + b * n_dims);
+    const auto batch_idx = grid.block_rank();
+    uint4* __restrict__ r_ptr = reinterpret_cast<uint4*>(r_data + batch_idx * n_dims);
+    uint4* __restrict__ x_ptr = reinterpret_cast<uint4*>(x_data + batch_idx * n_dims);
+    const uint4* __restrict__ b_ptr = reinterpret_cast<const uint4*>(bias);
 
     res_norm_t<T> ops;
@@ -118,7 +124,8 @@ __global__ void fusedAddResidualNorm(
     for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.num_threads()) {
         auto r = r_ptr[i];
         auto x = x_ptr[i];
-        r = ops.addvec(r, x, thread_sum);
+        uint4 b = b_ptr ? b_ptr[i] : uint4{};
+        r = ops.addvec(r, x, b, thread_sum);
         r_ptr[i] = r;
     }
@@ -136,8 +143,8 @@ __global__ void fusedAddResidualNorm(
 }
 
 template<typename T>
-void invokeFusedAddResidualRMSNorm(
-    T* residual, T* inout, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
+void invokeFusedAddBiasResidualRMSNorm(
+    T* residual, T* inout, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
 {
     constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
     FT_CHECK(n_dims % PACK_DIM == 0);
@@ -146,10 +153,12 @@ void invokeFusedAddResidualRMSNorm(
     int n_threads = (n_pack + n_iter - 1) / n_iter;  // adjust block size to avoid tail effect
     n_threads = (n_threads + 31) / 32 * 32;  // round up to the nearest multiple of warp size
-    fusedAddResidualNorm<<<batch_size, n_threads, 0, stream>>>(residual, inout, scale, eps, batch_size, n_dims);
+    fusedAddBiasResidualNorm<<<batch_size, n_threads, 0, stream>>>(
+        residual, inout, bias, scale, eps, batch_size, n_dims);
 }
 
-template void invokeFusedAddResidualRMSNorm(float*, float*, const float*, float, int, int, cudaStream_t);
-template void invokeFusedAddResidualRMSNorm(half*, half*, const half*, float, int, int, cudaStream_t);
+template void
+invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t);
+template void invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t);
 
 }  // namespace fastertransformer
@@ -5,7 +5,7 @@
 namespace fastertransformer {
 
 template<typename T>
-void invokeFusedAddResidualRMSNorm(
-    T* residual, T* inout, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream);
+void invokeFusedAddBiasResidualRMSNorm(
+    T* residual, T* inout, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream);
 
 }  // namespace fastertransformer
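
A rough fp32 reference for what the renamed kernel computes, as read from the visible code above; the exact epsilon placement inside the rsqrt is assumed to follow standard RMSNorm. The residual buffer is updated in place with `inout + bias`, then the normalized, scaled result is written back to `inout`.

    import torch

    def fused_add_bias_residual_rmsnorm(residual, inout, bias, scale, eps):
        # residual, inout: [batch, n_dims]; bias, scale: [n_dims] (bias may be None)
        residual += inout if bias is None else inout + bias
        rms = torch.rsqrt(residual.pow(2).mean(dim=-1, keepdim=True) + eps)
        inout.copy_(residual * rms * scale)
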
@@ -15,7 +15,8 @@
  * limitations under the License.
  */
 
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.cc
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.cc
 
 #include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h"
 #include "3rdparty/INIReader.h"

@@ -127,6 +128,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1);
     cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0);
     prefix_cache_len_ = reader.GetInteger("llama", "prefix_cache_len", 0);
+    attn_bias_ = reader.GetInteger("llama", "attn_bias", 0);
 
     handleMissingParams();
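
The new `attn_bias` entry is read from the `llama` section of the exported model config with a default of 0. An illustrative way to inspect the same entry from Python is shown below; the backend itself uses the C++ INIReader above, and `config.ini` here is a hypothetical path.

    from configparser import ConfigParser

    cfg = ConfigParser()
    cfg.read('config.ini')  # hypothetical path to the exported model config
    attn_bias = cfg.getint('llama', 'attn_bias', fallback=0)
    print('attention bias enabled:', bool(attn_bias))
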
@@ -284,6 +286,7 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
                                                       vocab_size_,
                                                       num_layer_,
                                                       weight_type_,
+                                                      attn_bias_,
                                                       tensor_para_size_,
                                                       tensor_para_rank,
                                                       prefix_cache_len_);
@@ -297,14 +300,14 @@ std::string LlamaTritonModel<T>::toString()
     std::stringstream ss;
     ss << "Model: "
        << "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_
-       << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ << "\nmax_batch_size: " << max_batch_size_
-       << "\nmax_context_token_num: " << max_context_token_num_ << "\nsession_len: " << session_len_
-       << "\nstep_length: " << step_length_ << "\ncache_max_entry_count: " << cache_max_entry_count_
-       << "\ncache_chunk_size: " << cache_chunk_size_ << "\nuse_context_fmha: " << use_context_fmha_
-       << "\nstart_id: " << start_id_ << "\ntensor_para_size: " << tensor_para_size_
-       << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_
-       << "\nmodel_name: " << model_name_ << "\nprefix_cache_len: " << prefix_cache_len_
-       << "\nmodel_dir: " << model_dir_ << std::endl;
+       << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ << "\nattn_bias: " << attn_bias_
+       << "\nmax_batch_size: " << max_batch_size_ << "\nmax_context_token_num: " << max_context_token_num_
+       << "\nsession_len: " << session_len_ << "\nstep_length: " << step_length_
+       << "\ncache_max_entry_count: " << cache_max_entry_count_ << "\ncache_chunk_size: " << cache_chunk_size_
+       << "\nuse_context_fmha: " << use_context_fmha_ << "\nstart_id: " << start_id_
+       << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_
+       << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_
+       << "\nprefix_cache_len: " << prefix_cache_len_ << "\nmodel_dir: " << model_dir_ << std::endl;
 
     return ss.str();
 }
......
@@ -15,7 +15,8 @@
  * limitations under the License.
  */
 
-// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h
+// Modified from
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h
 
 #pragma once

@@ -91,6 +92,7 @@ private:
     size_t tensor_para_size_;
     size_t pipeline_para_size_;
 
     ft::WeightType weight_type_;
+    bool attn_bias_;
     size_t prefix_cache_len_{};
......