Unverified Commit fe46dac2 authored by AllentDan, committed by GitHub

Add lint action (#32)

* temp

* fix lint

* csrc->src

* remove clang-format

* skip .rst

* skip doc

* clang-format

version

version

* mat_B
parent e8ab4ba3
@@ -98,19 +98,19 @@ template void invokeDecodingInitialize(bool* finished,
// PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts
template<typename T>
__global__ void embeddingLookupPosEncoding(T* from_tensor,
                                           const T* embedding_table,
                                           const T* position_encoding,
                                           const int* all_ids,
                                           const int* padding_count,
                                           const int* input_lengths,
                                           const int local_token_num,
                                           const int64_t hidden_units,
                                           const int step,
                                           const int max_input_length,
                                           const int token_num,
                                           const int ite,
                                           const T scale)
{
    // 1. lookup from embedding table
    // 2. multiply scale
...
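The kernel above follows the two steps in its comment: look each token id up in the embedding table, multiply by the scale, and add the position encoding for the current step. A minimal standalone CUDA sketch of that pattern follows; the names and the flat [token_num, hidden_units] layout are illustrative assumptions, not the FasterTransformer kernel itself.

template<typename T>
__global__ void embedding_lookup_sketch(T* out,              // [token_num, hidden_units]
                                        const T* embedding,  // [vocab_size, hidden_units]
                                        const T* pos_enc,    // [max_step, hidden_units]
                                        const int* token_ids,
                                        const int step,
                                        const int token_num,
                                        const int hidden_units,
                                        const T scale)
{
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < token_num * hidden_units;
         i += gridDim.x * blockDim.x) {
        const int token_idx  = i / hidden_units;
        const int hidden_idx = i % hidden_units;
        // 1. lookup from embedding table, 2. multiply scale, 3. add the position encoding for this step
        out[i] = embedding[token_ids[token_idx] * hidden_units + hidden_idx] * scale
                 + pos_enc[step * hidden_units + hidden_idx];
    }
}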
@@ -242,18 +242,18 @@ __global__ void inputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLo
    // embedding lookup from word ids [batch, beam, length] (part of [batch, beam, max_input_length]), [vocab,
    // hidden] and [batch, max_prefix_soft_prompt_length, hidden] to generate embedding [batch, beam, length +
    // max_prefix_soft_prompt_length, hidden]
    int tmp_index = index;
    const int hidden_id = tmp_index % param.hidden_units;
    tmp_index = (tmp_index - hidden_id) / param.hidden_units;
    const int seq_id = tmp_index % (param.max_prefix_soft_prompt_length + param.max_input_length);
    tmp_index = (tmp_index - seq_id) / (param.max_prefix_soft_prompt_length + param.max_input_length);
    const int beam_id = tmp_index % param.beam_width;
    tmp_index = (tmp_index - beam_id) / param.beam_width;
    const int batch_id = tmp_index % param.batch_size;
    const int64_t hidden_units = param.hidden_units;
    T embedding =
        (seq_id < param.prefix_soft_prompt_lengths[batch_id]) ?
            (T)param.prefix_soft_prompt_embedding[batch_id * param.max_prefix_soft_prompt_length * hidden_units
                                                  + seq_id * hidden_units + hidden_id] :
            param.embedding_table[param.input_ids[batch_id * param.beam_width * param.max_input_length
                                                  + beam_id * param.max_input_length
...
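The modulo/divide chain above peels a flat thread index apart innermost-first to recover the (batch, beam, seq, hidden) coordinates. A small self-contained host C++ sketch of the same decomposition, with hypothetical names and dimensions:

#include <cstdio>

struct Coord {
    int batch, beam, seq, hidden;
};

static Coord decompose(int index, int batch, int beam, int seq_len, int hidden)
{
    Coord c;
    c.hidden = index % hidden;
    index    = (index - c.hidden) / hidden;
    c.seq    = index % seq_len;
    index    = (index - c.seq) / seq_len;
    c.beam   = index % beam;
    index    = (index - c.beam) / beam;
    c.batch  = index % batch;
    return c;
}

int main()
{
    // Flat index 5 in a [2, 1, 3, 4] tensor -> batch 0, beam 0, seq 1, hidden 1.
    Coord c = decompose(5, 2, 1, 3, 4);
    std::printf("batch=%d beam=%d seq=%d hidden=%d\n", c.batch, c.beam, c.seq, c.hidden);
    return 0;
}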
@@ -21,50 +21,46 @@
#else
#include <cooperative_groups.h>
#endif
-#include <cuda_fp16.h>
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
+#include "src/fastertransformer/utils/cuda_type_utils.cuh"
+#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <float.h>
#include <type_traits>
-#include "src/fastertransformer/utils/cuda_type_utils.cuh"

namespace cg = cooperative_groups;

namespace fastertransformer {

-template <int VPT>
+template<int VPT>
struct BytesToType;

-template <>
-struct BytesToType<2>
-{
+template<>
+struct BytesToType<2> {
    using type = uint16_t;
};

-template <>
-struct BytesToType<4>
-{
+template<>
+struct BytesToType<4> {
    using type = uint32_t;
};

-template <>
-struct BytesToType<8>
-{
+template<>
+struct BytesToType<8> {
    using type = uint64_t;
};

-template <>
-struct BytesToType<16>
-{
+template<>
+struct BytesToType<16> {
    using type = float4;
};

-template <int Bytes>
+template<int Bytes>
__device__ inline void copy(const void* local, void* data)
{
    using T = typename BytesToType<Bytes>::type;
    const T* in = static_cast<const T*>(local);
    T* out = static_cast<T*>(data);
    *out = *in;
}

static const float HALF_FLT_MAX = 65504.F;
@@ -134,7 +130,6 @@ __inline__ __device__ T blockReduceMax(T val)
    return val;
}

/* Calculate the maximum of all elements in a block */
template<typename T>
__inline__ __device__ T blockAllReduceMax(T val)
...
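The BytesToType/copy helpers in the hunk above map a byte count to an integer type of the same size so that several packed elements can be moved with one load/store. A host-side sketch of the idea with the same structure but illustrative sizes; it relies on the same reinterpret-cast type punning as the device helper and assumes suitably aligned data.

#include <cassert>
#include <cstdint>
#include <cstring>

template<int Bytes>
struct BytesToType;
template<>
struct BytesToType<4> {
    using type = uint32_t;
};
template<>
struct BytesToType<8> {
    using type = uint64_t;
};

template<int Bytes>
inline void copy(const void* local, void* data)
{
    // Same type punning as the device helper: one word-sized move instead of Bytes byte copies.
    using T = typename BytesToType<Bytes>::type;
    const T* in = static_cast<const T*>(local);
    T* out = static_cast<T*>(data);
    *out = *in;
}

int main()
{
    // Move four 2-byte elements (8 bytes) with a single 64-bit load/store.
    alignas(8) uint16_t src[4] = {1, 2, 3, 4};
    alignas(8) uint16_t dst[4] = {0, 0, 0, 0};
    copy<sizeof(src)>(src, dst);
    assert(std::memcmp(src, dst, sizeof(src)) == 0);
    return 0;
}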
@@ -149,7 +149,7 @@ void invokeLengthCriterion(bool* finished,
    h_pinned_finished_sum_[0] = -1;

    length_criterion<<<grid, block, 0, stream>>>(
        finished, should_stop, h_pinned_finished_sum_, sequence_limit_length, batch_size, beam_width, step);
    while (((volatile int*)h_pinned_finished_sum_)[0] == -1) {};
    sync_check_cuda_error();
...
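invokeLengthCriterion above seeds a pinned host flag with -1, launches the kernel, and busy-waits on a volatile read of that flag instead of synchronizing the whole stream. A simplified CUDA sketch of the idiom, assuming mapped pinned memory and omitting error handling; the names are illustrative.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void write_flag(int* flag, int value)
{
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        *flag = value;  // lands in the mapped pinned allocation, visible to the host
    }
}

int main()
{
    int* h_flag = nullptr;
    cudaHostAlloc(&h_flag, sizeof(int), cudaHostAllocMapped);
    h_flag[0] = -1;  // sentinel: kernel has not reported yet

    int* d_flag = nullptr;
    cudaHostGetDevicePointer(&d_flag, h_flag, 0);

    write_flag<<<1, 32>>>(d_flag, 42);
    while (((volatile int*)h_flag)[0] == -1) {}  // spin on the pinned flag instead of a stream sync

    std::printf("flag = %d\n", h_flag[0]);
    cudaFreeHost(h_flag);
    return 0;
}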
@@ -1472,7 +1472,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
            k = *reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx);
        }
    }

-    if (!is_masked && !q_buf) {  // also skip modifing QKV if q/k/v_buf are present
+    if (!is_masked && !q_buf) {  // also skip modifying QKV if q/k/v_buf are present
        *reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = q;
        *reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k;
        *reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v;
...
@@ -23,4 +23,4 @@ set_property(TARGET DynamicDecodeLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(DynamicDecodeLayer PUBLIC -lcudart
                      TopKSamplingLayer TopPSamplingLayer
                      OnlineBeamSearchLayer BeamSearchLayer ban_bad_words stop_criteria
-                     gpt_kernels tensor nvtx_utils)
\ No newline at end of file
+                     gpt_kernels tensor nvtx_utils)
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/fastertransformer/layers/FfnFP8Layer.h"
#include "src/fastertransformer/kernels/activation_fp8_kernels.h"
#include "src/fastertransformer/utils/cublasFP8MMWrapper.h"
#include "src/fastertransformer/utils/nvtx_utils.h"

namespace fastertransformer {

template<typename T1, typename T2>
void FfnFP8Layer<T1, T2>::forward(TensorMap* output_tensors,
                                  TensorMap* input_tensors,
                                  const FfnFP8Weight<T1, T2>* ffn_weights)
{
    // input tensors:
    //     input_hidden_state [token_num, d_model],
    // output tensors:
    //     output_hidden_state [token_num, d_model],

    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    FT_CHECK(input_tensors->size() == 1);
    FT_CHECK(output_tensors->size() == 1);

    const int m       = input_tensors->at("input_hidden_state").shape[0];
    const int d_model = input_tensors->at("input_hidden_state").shape[1];

    const T1* input_hidden_state = input_tensors->at("input_hidden_state").getPtr<T1>();
    Tensor    output_tensor      = output_tensors->at("output_hidden_state");
    allocateBuffer(m);

#ifdef FUSE_GEMM_ACT
    if (fp8_mode_ == 1) {
        const float alpha = 1.0f;
        const float beta  = 0.0f;
        reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
            ->Gemm(inter_buf_bf16_,
                   (int)1,
                   (int)m,
                   (int)inter_size_,
                   (int)d_model,
                   (int64_t)0,
                   (int64_t)0,
                   (int64_t)0,
                   &alpha,
                   &beta,
                   input_hidden_state,
                   ffn_weights->intermediate_weight.kernel,
                   ffn_weights->intermediate_weight.input_scale,
                   ffn_weights->intermediate_weight.per_channel_scale_min,  // identity_scale
                   stream_);
        invokeAddBiasActivation(m,
                                ffn_weights->intermediate_weight.bias,
                                ffn_weights->intermediate_weight.output_scale,
                                ffn_weights->intermediate_weight.scale,
                                ffn_weights->intermediate_weight.per_channel_scale_min,
                                ffn_weights->output_weight.input_scale_inv);
    }
    else if (fp8_mode_ == 2) {
#ifdef USE_QGMMA
        if (getActivationType() == ActivationType::Gelu) {
            PUSH_RANGE("FFN gemm 1 bias gelu");
            reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
                ->Conv1x1Gemm<false, true>(inter_buf_,
                                           m,
                                           inter_size_,
                                           d_model,
                                           input_hidden_state,
                                           ffn_weights->intermediate_weight.kernel,
                                           ffn_weights->intermediate_weight.bias,
                                           *(ffn_weights->intermediate_weight.input_h_scale),   // scale_a,
                                           *(ffn_weights->intermediate_weight.weight_h_scale),  // scale_b,
                                           *(ffn_weights->output_weight.input_h_scale_inv),     // scale_d,
                                           stream_);
            POP_RANGE;
        }
        else if (getActivationType() == ActivationType::Relu) {
            reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
                ->Conv1x1Gemm<true, false>(inter_buf_,
                                           m,
                                           inter_size_,
                                           d_model,
                                           input_hidden_state,
                                           ffn_weights->intermediate_weight.kernel,
                                           ffn_weights->intermediate_weight.bias,
                                           *(ffn_weights->intermediate_weight.input_h_scale),   // scale_a,
                                           *(ffn_weights->intermediate_weight.weight_h_scale),  // scale_b,
                                           *(ffn_weights->output_weight.input_h_scale_inv),     // scale_d,
                                           stream_);
        }
#else   // USE_QGMMA
        const float alpha = 1.0f;
        const float beta  = 0.0f;
        if (getActivationType() == ActivationType::Gelu) {
            reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
#ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE
                ->Gemm_Bias_Act<false, true>(inter_buf_bf16_,
#else   // FP8_GEMM_OUTPUT_QUANT_DISABLE
                ->Gemm_Bias_Act<false, true>(inter_buf_,
#endif  // FP8_GEMM_OUTPUT_QUANT_DISABLE
                                             (int)1,
                                             (int)m,
                                             (int)inter_size_,
                                             (int)d_model,
                                             (int64_t)0,
                                             (int64_t)0,
                                             (int64_t)0,
                                             &alpha,
                                             &beta,
                                             input_hidden_state,
                                             ffn_weights->intermediate_weight.kernel,
                                             ffn_weights->intermediate_weight.input_scale,
                                             ffn_weights->intermediate_weight.weight_scale,
                                             ffn_weights->intermediate_weight.bias,
                                             ffn_weights->intermediate_weight.output_scale,
                                             stream_);
        }
        else if (getActivationType() == ActivationType::Relu) {
            reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
#ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE
                ->Gemm_Bias_Act<true, false>(inter_buf_bf16_,
#else   // FP8_GEMM_OUTPUT_QUANT_DISABLE
                ->Gemm_Bias_Act<true, false>(inter_buf_,
#endif  // #ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE
                                             (int)1,
                                             (int)m,
                                             (int)inter_size_,
                                             (int)d_model,
                                             (int64_t)0,
                                             (int64_t)0,
                                             (int64_t)0,
                                             &alpha,
                                             &beta,
                                             input_hidden_state,
                                             ffn_weights->intermediate_weight.kernel,
                                             ffn_weights->intermediate_weight.input_scale,
                                             ffn_weights->intermediate_weight.weight_scale,
                                             ffn_weights->intermediate_weight.bias,
                                             ffn_weights->intermediate_weight.output_scale,
                                             stream_);
        }
#ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE
        invokeQuantizeMatrix<T1, T2, QUANTIZE_MODE::PER_TENSOR>(
            inter_buf_, ffn_weights->output_weight.input_scale_inv, inter_buf_bf16_, m * inter_size_, 1, stream_);
#endif  // FP8_GEMM_OUTPUT_QUANT_DISABLE
#endif  // USE_QGMMA
    }
#else   // FUSE_GEMM_ACT
    PUSH_RANGE("FFN gemm 1");
#ifdef SPARSITY_ENABLED
    int m_tmp = m;
    if (m_tmp % 8 != 0) {
        m_tmp = (m_tmp / 8 + 1) * 8;
    }
    const int m_padded = m_tmp;
    if (sparse_ && cublas_wrapper_->isUseSparse(1, inter_size_, m, d_model)) {
        FT_CHECK(false);
        // cublas_wrapper_->SpGemm(CUBLAS_OP_N,
        //                         CUBLAS_OP_N,
        //                         inter_size_,
        //                         m_padded,
        //                         d_model,
        //                         ffn_weights->intermediate_weight.sp_kernel,
        //                         input_hidden_state,
        //                         inter_buf_);
    }
    else {
#endif  // SPARSITY_ENABLED
        if (fp8_mode_ == 1) {
            const float alpha = 1.0f;
            const float beta  = 0.0f;
            reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
                ->Gemm(inter_buf_bf16_,
                       (int)1,
                       (int)m,
                       (int)inter_size_,
                       (int)d_model,
                       (int64_t)0,
                       (int64_t)0,
                       (int64_t)0,
                       &alpha,
                       &beta,
                       input_hidden_state,
                       ffn_weights->intermediate_weight.kernel,
                       ffn_weights->intermediate_weight.input_scale,
                       ffn_weights->intermediate_weight.per_channel_scale_min,  // identity_scale
                       stream_);
        }
        else if (fp8_mode_ == 2) {
            const float alpha = 1.0f;
            const float beta  = 0.0f;
            reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
                ->Gemm(inter_buf_bf16_,
                       (int)1,
                       (int)m,
                       (int)inter_size_,
                       (int)d_model,
                       (int64_t)0,
                       (int64_t)0,
                       (int64_t)0,
                       &alpha,
                       &beta,
                       input_hidden_state,
                       ffn_weights->intermediate_weight.kernel,
                       ffn_weights->intermediate_weight.input_scale,
                       ffn_weights->intermediate_weight.weight_scale,
                       stream_);
        }
#ifdef SPARSITY_ENABLED
    }
#endif  // SPARSITY_ENABLED
    POP_RANGE;

    PUSH_RANGE("FFN add bias act");
    if (fp8_mode_ == 1) {
        invokeAddBiasActivation(m,
                                ffn_weights->intermediate_weight.bias,
                                ffn_weights->intermediate_weight.output_scale,
                                ffn_weights->intermediate_weight.scale,
                                ffn_weights->intermediate_weight.per_channel_scale_min,
                                ffn_weights->output_weight.input_scale_inv);
    }
    else if (fp8_mode_ == 2) {
        invokeAddBiasActivation(m,
                                ffn_weights->intermediate_weight.bias,
                                ffn_weights->intermediate_weight.output_scale,
                                nullptr,
                                nullptr,
                                ffn_weights->output_weight.input_scale_inv);
    }
    sync_check_cuda_error();
    POP_RANGE;
#endif  // FUSE_GEMM_ACT

    PUSH_RANGE("FFN gemm 2");
#ifdef SPARSITY_ENABLED
    if (sparse_ && cublas_wrapper_->isUseSparse(1, d_model, m, inter_size_)) {
        FT_CHECK(false);
        // cublas_wrapper_->SpGemm(CUBLAS_OP_N,
        //                         CUBLAS_OP_N,
        //                         d_model,
        //                         m_padded,
        //                         inter_size_,
        //                         ffn_weights->output_weight.sp_kernel,
        //                         inter_buf_,
        //                         output_tensor);
    }
    else {
#endif  // SPARSITY_ENABLED
        if (fp8_mode_ == 1) {
            const float alpha = 1.0f;
            const float beta  = 0.0f;
            if (output_tensor.type == TYPE_BF16) {
                reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
                    ->Gemm(output_tensor.getPtr<T2>(),
                           (int)1,
                           (int)m,
                           (int)d_model,
                           (int)inter_size_,
                           (int64_t)0,
                           (int64_t)0,
                           (int64_t)0,
                           &alpha,
                           &beta,
                           (const __nv_fp8_e4m3*)inter_buf_,
                           (const __nv_fp8_e4m3*)ffn_weights->output_weight.kernel,
                           ffn_weights->output_weight.input_scale,
                           ffn_weights->identity_scale,
                           stream_);
            }
            else if (output_tensor.type == TYPE_FP8_E4M3) {
                const float alpha = 1.0f;
                const float beta  = 0.0f;
                reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
                    ->Gemm(output_tensor.getPtr<T1>(),
                           (int)1,
                           (int)m,
                           (int)d_model,
                           (int)inter_size_,
                           (int64_t)0,
                           (int64_t)0,
                           (int64_t)0,
                           &alpha,
                           &beta,
                           (const __nv_fp8_e4m3*)inter_buf_,
                           (const __nv_fp8_e4m3*)ffn_weights->output_weight.kernel,
                           ffn_weights->output_weight.input_scale,
                           ffn_weights->output_weight.per_channel_scale_min,
                           ffn_weights->output_weight.output_scale_inv,
                           stream_);
            }
            else {
                FT_CHECK(false);
            }
        }
        else if (fp8_mode_ == 2) {
            if (output_tensor.type == TYPE_BF16) {
                const float alpha = 1.0f;
                const float beta  = 0.0f;
                reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
                    ->Gemm(output_tensor.getPtr<T2>(),
                           (int)1,
                           (int)m,
                           (int)d_model,
                           (int)inter_size_,
                           (int64_t)0,
                           (int64_t)0,
                           (int64_t)0,
                           &alpha,
                           &beta,
                           (const __nv_fp8_e4m3*)inter_buf_,
                           (const __nv_fp8_e4m3*)ffn_weights->output_weight.kernel,
                           ffn_weights->output_weight.input_scale,
                           ffn_weights->output_weight.weight_scale,
                           stream_);
            }
            else if (output_tensor.type == TYPE_FP8_E4M3) {
                // It looks like conv1x1Gemm does not bring better performance for this gemm
                // because the k dimension of this gemm is large
                // #ifdef USE_QGMMA
                //     reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
                //         ->Conv1x1Gemm<false, false>(output_tensor.getPtr<T1>(),
                //                                     m,
                //                                     d_model,
                //                                     inter_size_,
                //                                     inter_buf_,
                //                                     ffn_weights->output_weight.kernel,
                //                                     ffn_weights->output_weight.bias,
                //                                     *(ffn_weights->output_weight.input_h_scale),       // scale_a,
                //                                     *(ffn_weights->output_weight.weight_h_scale),      // scale_b,
                //                                     *(ffn_weights->output_weight.output_h_scale_inv),  // scale_d,
                //                                     stream_);
                // #else   // USE_QGMMA
                const float alpha = 1.0f;
                const float beta  = 0.0f;
                reinterpret_cast<cublasFP8MMWrapper*>(cublas_wrapper_)
                    ->Gemm(output_tensor.getPtr<T1>(),
                           (int)1,
                           (int)m,
                           (int)d_model,
                           (int)inter_size_,
                           (int64_t)0,
                           (int64_t)0,
                           (int64_t)0,
                           &alpha,
                           &beta,
                           (const __nv_fp8_e4m3*)inter_buf_,
                           (const __nv_fp8_e4m3*)ffn_weights->output_weight.kernel,
                           ffn_weights->output_weight.input_scale,
                           ffn_weights->output_weight.weight_scale,
                           ffn_weights->output_weight.output_scale_inv,
                           stream_);
                // #endif  // USE_QGMMA
            }
            else {
                FT_CHECK(false);
            }
        }
#ifdef SPARSITY_ENABLED
    }
#endif  // SPARSITY_ENABLED
    POP_RANGE;

    sync_check_cuda_error();
    if (is_free_buffer_after_forward_ == true) {
        freeBuffer();
    }
    sync_check_cuda_error();
}
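The SPARSITY_ENABLED branch of forward() above rounds the token count m up to a multiple of 8 before the (currently disabled) sparse GEMM, via m_tmp = (m_tmp / 8 + 1) * 8. A tiny standalone sketch of that rounding; the helper name is hypothetical.

#include <cassert>

static int round_up(int m, int multiple)
{
    return (m % multiple == 0) ? m : (m / multiple + 1) * multiple;
}

int main()
{
    assert(round_up(16, 8) == 16);  // already aligned, left unchanged
    assert(round_up(17, 8) == 24);  // padded up for the sparse GEMM
    return 0;
}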
template<typename T1, typename T2>
FfnFP8Layer<T1, T2>::FfnFP8Layer(size_t inter_size,
                                 int fp8_mode,
                                 cudaStream_t stream,
                                 cublasMMWrapper* cublas_wrapper,
                                 IAllocator* allocator,
                                 bool is_free_buffer_after_forward,
                                 bool sparse):
    BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr, sparse),
    inter_size_(inter_size),
    fp8_mode_(fp8_mode)
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
}

template<typename T1, typename T2>
FfnFP8Layer<T1, T2>::FfnFP8Layer(FfnFP8Layer<T1, T2> const& ffn_layer):
    BaseLayer(ffn_layer.stream_,
              ffn_layer.cublas_wrapper_,
              ffn_layer.allocator_,
              ffn_layer.is_free_buffer_after_forward_,
              ffn_layer.cuda_device_prop_,
              ffn_layer.sparse_),
    inter_size_(ffn_layer.inter_size_),
    fp8_mode_(ffn_layer.fp8_mode_)
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
}

template<typename T1, typename T2>
FfnFP8Layer<T1, T2>::~FfnFP8Layer()
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    cublas_wrapper_ = nullptr;
    freeBuffer();
}

template<typename T1, typename T2>
void FfnFP8Layer<T1, T2>::allocateBuffer()
{
    FT_CHECK(false);
}

template<typename T1, typename T2>
void FfnFP8Layer<T1, T2>::allocateBuffer(size_t token_num)
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    inter_buf_      = (T1*)allocator_->reMalloc(inter_buf_, sizeof(T1) * token_num * inter_size_, false);
    inter_buf_bf16_ = (T2*)allocator_->reMalloc(inter_buf_bf16_, sizeof(T2) * token_num * inter_size_, false);
    is_allocate_buffer_ = true;
}

template<typename T1, typename T2>
void FfnFP8Layer<T1, T2>::freeBuffer()
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    if (is_allocate_buffer_) {
        allocator_->free((void**)(&inter_buf_));
        allocator_->free((void**)(&inter_buf_bf16_));
        is_allocate_buffer_ = false;
    }
}

template class FfnFP8Layer<__nv_fp8_e4m3, __nv_bfloat16>;

template<typename T1, typename T2>
GeluFfnFP8Layer<T1, T2>::GeluFfnFP8Layer(size_t inter_size,
                                         int fp8_mode,
                                         cudaStream_t stream,
                                         cublasMMWrapper* cublas_wrapper,
                                         IAllocator* allocator,
                                         bool is_free_buffer_after_forward,
                                         bool sparse):
    FfnFP8Layer<T1, T2>(inter_size, fp8_mode, stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse)
{
}

template<typename T1, typename T2>
GeluFfnFP8Layer<T1, T2>::GeluFfnFP8Layer(GeluFfnFP8Layer<T1, T2> const& gelu_ffn_layer):
    FfnFP8Layer<T1, T2>(gelu_ffn_layer)
{
}

template<typename T1, typename T2>
void GeluFfnFP8Layer<T1, T2>::invokeAddBiasActivation(const int m,
                                                      const T2* bias,
                                                      const float* input_scale,
                                                      const float* input_scale_2,
                                                      const float* input_scale_2_min,
                                                      const float* output_scale)
{
    FP8ActivationParam<T1, T2> param{inter_buf_bf16_,
                                     inter_buf_,
                                     bias,
                                     input_scale,
                                     input_scale_2,
                                     input_scale_2_min,
                                     output_scale,
                                     (uint32_t)m,
                                     (uint32_t)inter_size_,
                                     stream_};
    invokeFP8AddBiasGelu<T1, T2>(param);
}
template class GeluFfnFP8Layer<__nv_fp8_e4m3, __nv_bfloat16>;

template<typename T1, typename T2>
ReluFfnFP8Layer<T1, T2>::ReluFfnFP8Layer(size_t inter_size,
                                         int fp8_mode,
                                         cudaStream_t stream,
                                         cublasMMWrapper* cublas_wrapper,
                                         IAllocator* allocator,
                                         bool is_free_buffer_after_forward,
                                         bool sparse):
    FfnFP8Layer<T1, T2>(inter_size, fp8_mode, stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse)
{
}

template<typename T1, typename T2>
ReluFfnFP8Layer<T1, T2>::ReluFfnFP8Layer(ReluFfnFP8Layer<T1, T2> const& relu_ffn_layer):
    FfnFP8Layer<T1, T2>(relu_ffn_layer)
{
}

template<typename T1, typename T2>
void ReluFfnFP8Layer<T1, T2>::invokeAddBiasActivation(const int m,
                                                      const T2* bias,
                                                      const float* input_scale,
                                                      const float* input_scale_2,
                                                      const float* input_scale_2_min,
                                                      const float* output_scale)
{
    FP8ActivationParam<T1, T2> param{inter_buf_bf16_,
                                     inter_buf_,
                                     bias,
                                     input_scale,
                                     input_scale_2,
                                     input_scale_2_min,
                                     output_scale,
                                     (uint32_t)m,
                                     (uint32_t)inter_size_,
                                     stream_};
    invokeFP8AddBiasRelu<T1, T2>(param);
}
template class ReluFfnFP8Layer<__nv_fp8_e4m3, __nv_bfloat16>;

}  // namespace fastertransformer
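Throughout this file, the FP8 path hands per-tensor scales and inverse scales (input_scale, output_scale_inv, and so on) to the GEMM and activation kernels. A conceptual, self-contained sketch of that scale/inverse-scale quantization, using float as a stand-in for __nv_fp8_e4m3 and showing only the E4M3 clamp at 448; the helper names are illustrative, not FasterTransformer APIs.

#include <algorithm>
#include <cassert>
#include <cmath>

// Quantize: multiply by the inverse scale and clamp to the E4M3 representable range.
static float quantize(float x, float scale_inv)
{
    const float fp8_e4m3_max = 448.0f;
    return std::max(-fp8_e4m3_max, std::min(fp8_e4m3_max, x * scale_inv));
}

// Dequantize: multiply the stored value back by the scale.
static float dequantize(float q, float scale)
{
    return q * scale;
}

int main()
{
    const float scale = 2.0f;    // in a real setup, roughly amax / fp8_max
    const float x     = 600.0f;  // would clip in FP8 without scaling
    const float q     = quantize(x, 1.0f / scale);
    assert(std::fabs(dequantize(q, scale) - x) < 1e-3f);
    return 0;
}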
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "FfnWeight.h"
#include "src/fastertransformer/utils/ScaleList.h"

namespace fastertransformer {

template<typename T1, typename T2>
struct FfnFP8Weight: FfnWeight<T1, T2> {
    ScaleList* scale_list_ptr;
    float* identity_scale;
    float* identity_h_scale;
};

}  // namespace fastertransformer
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "FfnWeight.h"
#include "src/fastertransformer/utils/ScaleList.h"

namespace fastertransformer {

template<typename T>
struct FfnINT8Weight: FfnWeight<T> {
    ScaleList* scale_list_ptr;
};

}  // namespace fastertransformer
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "FfnLayerINT8.h"
#include "src/fastertransformer/utils/nvtx_utils.h"

namespace fastertransformer {

template<typename T>
void FfnLayerINT8<T>::forward(std::vector<fastertransformer::Tensor>* output_tensors,
                              const std::vector<fastertransformer::Tensor>* input_tensors,
                              const FfnWeight<T>* ffn_weights)
{
    // input_tensors: [input (token_num, hidden_dimension)]
    // output_tensors: [output (token_num, hidden_dimension)]

    ScaleList* scale_list = ((const FfnINT8Weight<T>*)ffn_weights)->scale_list_ptr;
    cublasINT8MMWrapper* cublas_wrapper = (cublasINT8MMWrapper*)cublas_wrapper_;

    FT_CHECK(isValidTokenNum(input_tensors->at(0).shape[0]));
    allocateBuffer();

    const int m = static_cast<int>(input_tensors->at(0).shape[0]);
#ifdef SPARSITY_ENABLED
    int m_tmp = m;
    if (m_tmp % 16 != 0) {
        m_tmp = (m_tmp / 16 + 1) * 16;
    }
    const int m_padded = m_tmp;
#endif

    int32_t* output_tensor = output_tensors->at(0).getPtr<int32_t>();
    const int8_t* input_tensor = input_tensors->at(0).getPtr<const int8_t>();

    PUSH_RANGE("FFN gemm 1");
    if (int8_mode_ == 1) {
        cublas_wrapper->Gemm(inter_int_buf_,
                             1,
                             m,
                             inter_size_,
                             hidden_units_,
                             0,
                             0,
                             0,
                             input_tensor,
                             (int8_t*)(ffn_weights->intermediate_weight.kernel));
    }
    else if (int8_mode_ == 2 || int8_mode_ == 3) {
#ifdef SPARSITY_ENABLED
        if (sparse_) {
            cublas_wrapper->SpGemm(inter_size_,
                                   m_padded,
                                   hidden_units_,
                                   scale_list->h_scale_list_[scale_list->p3_offset_ + 6],
                                   (int8_t*)(ffn_weights->intermediate_weight.sp_kernel),
                                   input_tensor,
                                   (int8_t*)inter_int_buf_);
        }
        else {
#endif
            cublas_wrapper->Gemm((int8_t*)inter_int_buf_,
                                 1,
                                 m,
                                 inter_size_,
                                 hidden_units_,
                                 0,
                                 0,
                                 0,
                                 scale_list->h_scale_list_[scale_list->p3_offset_ + 6],
                                 input_tensor,
                                 (int8_t*)(ffn_weights->intermediate_weight.kernel));
#ifdef SPARSITY_ENABLED
        }
#endif
    }
    POP_RANGE;

    PUSH_RANGE("add bias act");
    invokeAddBiasActivation(m, ffn_weights->intermediate_weight.bias, scale_list);
    POP_RANGE;
    sync_check_cuda_error();

    PUSH_RANGE("FFN gemm 2");
    if (int8_mode_ == 1) {
        cublas_wrapper->Gemm(output_tensor,
                             1,
                             m,
                             hidden_units_,
                             inter_size_,
                             0,
                             0,
                             0,
                             inter_buf_,
                             (int8_t*)(ffn_weights->output_weight.kernel));
    }
    else if (int8_mode_ == 2 || int8_mode_ == 3) {
#ifdef SPARSITY_ENABLED
        if (sparse_) {
            cublas_wrapper->SpGemm(hidden_units_,
                                   m_padded,
                                   inter_size_,
                                   scale_list->h_scale_list_[scale_list->p3_offset_ + 7],
                                   (int8_t*)(ffn_weights->output_weight.sp_kernel),
                                   inter_buf_,
                                   (int8_t*)output_tensor);
        }
        else {
#endif
            cublas_wrapper->Gemm((int8_t*)output_tensor,
                                 1,
                                 m,
                                 hidden_units_,
                                 inter_size_,
                                 0,
                                 0,
                                 0,
                                 scale_list->h_scale_list_[scale_list->p3_offset_ + 7],
                                 inter_buf_,
                                 (int8_t*)(ffn_weights->output_weight.kernel));
#ifdef SPARSITY_ENABLED
        }
#endif
    }
    POP_RANGE;

    sync_check_cuda_error();
    if (is_free_buffer_after_forward_ == true) {
        freeBuffer();
    }
    sync_check_cuda_error();
}
template<typename T>
FfnLayerINT8<T>::FfnLayerINT8(size_t max_batch_size,
                              size_t max_seq_len,
                              size_t head_num,
                              size_t size_per_head,
                              size_t inter_size,
                              int int8_mode,
                              cudaStream_t stream,
                              cublasMMWrapper* cublas_wrapper,
                              IAllocator* allocator,
                              bool is_free_buffer_after_forward,
                              bool sparse):
    BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
    max_token_num_(max_batch_size * max_seq_len),
    head_num_(head_num),
    size_per_head_(size_per_head),
    hidden_units_(head_num * size_per_head),
    inter_size_(inter_size),
    int8_mode_(int8_mode),
    sparse_(sparse)
{
}

template<typename T>
FfnLayerINT8<T>::FfnLayerINT8(FfnLayerINT8<T> const& ffn_layer):
    BaseLayer(
        ffn_layer.stream_, ffn_layer.cublas_wrapper_, ffn_layer.allocator_, ffn_layer.is_free_buffer_after_forward_),
    max_token_num_(ffn_layer.max_token_num_),
    head_num_(ffn_layer.head_num_),
    size_per_head_(ffn_layer.size_per_head_),
    hidden_units_(ffn_layer.hidden_units_),
    inter_size_(ffn_layer.inter_size_),
    int8_mode_(ffn_layer.int8_mode_),
    sparse_(ffn_layer.sparse_)
{
}

template<typename T>
FfnLayerINT8<T>::~FfnLayerINT8()
{
    cublas_wrapper_ = nullptr;
    freeBuffer();
}

template<typename T>
void FfnLayerINT8<T>::allocateBuffer()
{
    if (is_allocate_buffer_ == false) {
        inter_int_buf_ =
            (int32_t*)allocator_->reMalloc(inter_int_buf_, sizeof(int32_t) * max_token_num_ * inter_size_, false);
        inter_buf_ = (int8_t*)allocator_->reMalloc(inter_buf_, sizeof(int8_t) * max_token_num_ * inter_size_, false);
        is_allocate_buffer_ = true;
    }
}

template<typename T>
void FfnLayerINT8<T>::freeBuffer()
{
    if (is_allocate_buffer_ == true) {
        allocator_->free((void**)(&inter_int_buf_));
        allocator_->free((void**)(&inter_buf_));
        is_allocate_buffer_ = false;
    }
}

template<typename T>
bool FfnLayerINT8<T>::isValidTokenNum(size_t token_num)
{
    if (max_token_num_ == 0) {
        max_token_num_ = token_num;
        return true;
    }
    else {
        return token_num <= max_token_num_;
    }
}

template class FfnLayerINT8<float>;
template class FfnLayerINT8<half>;

template<typename T>
GeluFfnLayerINT8<T>::GeluFfnLayerINT8(size_t max_batch_size,
                                      size_t max_seq_len,
                                      size_t head_num,
                                      size_t size_per_head,
                                      size_t inter_size,
                                      int int8_mode,
                                      cudaStream_t stream,
                                      cublasMMWrapper* cublas_wrapper,
                                      IAllocator* allocator,
                                      bool is_free_buffer_after_forward,
                                      bool sparse):
    FfnLayerINT8<T>(max_batch_size,
                    max_seq_len,
                    head_num,
                    size_per_head,
                    inter_size,
                    int8_mode,
                    stream,
                    cublas_wrapper,
                    allocator,
                    is_free_buffer_after_forward,
                    sparse)
{
}

template<typename T>
GeluFfnLayerINT8<T>::GeluFfnLayerINT8(GeluFfnLayerINT8<T> const& gelu_ffn_layer): FfnLayerINT8<T>(gelu_ffn_layer)
{
}

template<typename T>
void GeluFfnLayerINT8<T>::invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list)
{
    if (int8_mode_ == 1) {
        invokeAddBiasGeluCol32<T>(inter_buf_,
                                  inter_int_buf_,
                                  bias,
                                  m,
                                  inter_size_,
                                  stream_,
                                  &(scale_list->d_scale_list_[scale_list->p2_offset_ + 4 * hidden_units_]),
                                  &(scale_list->d_scale_list_[44 + 2]),
                                  &(scale_list->d_scale_list_[52 + 3]));
    }
    else if (int8_mode_ == 2 || int8_mode_ == 3) {
#ifdef SPARSITY_ENABLED
        if (sparse_) {
            invokeAddBiasGeluRow<T>(inter_buf_,
                                    (const int8_t*)inter_int_buf_,
                                    bias,
                                    m,
                                    inter_size_,
                                    stream_,
                                    &(scale_list->d_scale_list_[48 + 1]),
                                    &(scale_list->d_scale_list_[52 + 3]));
        }
        else {
#endif
            invokeAddBiasGeluCol32<T>(inter_buf_,
                                      (const int8_t*)inter_int_buf_,
                                      bias,
                                      m,
                                      inter_size_,
                                      stream_,
                                      &(scale_list->d_scale_list_[48 + 1]),
                                      &(scale_list->d_scale_list_[52 + 3]));
#ifdef SPARSITY_ENABLED
        }
#endif
    }
}
template class GeluFfnLayerINT8<float>;
template class GeluFfnLayerINT8<half>;

template<typename T>
ReluFfnLayerINT8<T>::ReluFfnLayerINT8(size_t max_batch_size,
                                      size_t max_seq_len,
                                      size_t head_num,
                                      size_t size_per_head,
                                      size_t inter_size,
                                      int int8_mode,
                                      cudaStream_t stream,
                                      cublasMMWrapper* cublas_wrapper,
                                      IAllocator* allocator,
                                      bool is_free_buffer_after_forward):
    FfnLayerINT8<T>(max_batch_size,
                    max_seq_len,
                    head_num,
                    size_per_head,
                    inter_size,
                    int8_mode,
                    stream,
                    cublas_wrapper,
                    allocator,
                    is_free_buffer_after_forward)
{
}

template<typename T>
ReluFfnLayerINT8<T>::ReluFfnLayerINT8(ReluFfnLayerINT8<T> const& relu_ffn_layer): FfnLayerINT8<T>(relu_ffn_layer)
{
}

template<typename T>
void ReluFfnLayerINT8<T>::invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list)
{
    // TODO
}
template class ReluFfnLayerINT8<float>;
template class ReluFfnLayerINT8<half>;

}  // namespace fastertransformer
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "FfnINT8Weight.h"
#include "src/fastertransformer/kernels/activation_int8_kernels.h"
#include "src/fastertransformer/layers/BaseLayer.h"
#include "src/fastertransformer/utils/ScaleList.h"
#include "src/fastertransformer/utils/Tensor.h"
#include "src/fastertransformer/utils/allocator.h"
#include "src/fastertransformer/utils/cublasINT8MMWrapper.h"
#include "src/fastertransformer/utils/memory_utils.h"
#include <vector>

namespace fastertransformer {

template<typename T>
class GeluFfnLayerINT8;
template<typename T>
class ReluFfnLayerINT8;

template<typename T>
class FfnLayerINT8: public BaseLayer {
private:
    // buffer handling
    size_t max_token_num_ = 0;

    // meta data
    size_t head_num_;
    size_t size_per_head_;

    // calculated data
    size_t hidden_units_;

    void allocateBuffer() override;
    void freeBuffer() override;
    bool isValidTokenNum(size_t token_num);

protected:
    size_t inter_size_;
    int int8_mode_;
    bool sparse_;
    int* inter_int_buf_;
    int8_t* inter_buf_;
    virtual void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) = 0;

public:
    FfnLayerINT8(size_t max_batch_size,
                 size_t max_seq_len,
                 size_t head_num,
                 size_t size_per_head,
                 size_t inter_size,
                 int int8_mode,
                 cudaStream_t stream,
                 cublasMMWrapper* cublas_wrapper,
                 IAllocator* allocator,
                 bool is_free_buffer_after_forward,
                 bool sparse = false);

    FfnLayerINT8(FfnLayerINT8<T> const& ffn_layer);

    ~FfnLayerINT8();

    void forward(std::vector<fastertransformer::Tensor>* output_tensors,
                 const std::vector<fastertransformer::Tensor>* input_tensors,
                 const FfnWeight<T>* ffn_weights);

    friend GeluFfnLayerINT8<T>;
    friend ReluFfnLayerINT8<T>;
};

template<typename T>
class GeluFfnLayerINT8: public FfnLayerINT8<T> {
public:
    GeluFfnLayerINT8(size_t max_batch_size,
                     size_t max_seq_len,
                     size_t head_num,
                     size_t size_per_head,
                     size_t inter_size,
                     int int8_mode,
                     cudaStream_t stream,
                     cublasMMWrapper* cublas_wrapper,
                     IAllocator* allocator,
                     bool is_free_buffer_after_forward,
                     bool sparse = false);
    GeluFfnLayerINT8(GeluFfnLayerINT8<T> const& ffn_layer);
~GeluFfnLayerINT8() = default; ~GeluFfnLayerINT8() = default;
private: private:
using FfnLayerINT8<T>::inter_int_buf_; using FfnLayerINT8<T>::inter_int_buf_;
using FfnLayerINT8<T>::inter_buf_; using FfnLayerINT8<T>::inter_buf_;
using FfnLayerINT8<T>::inter_size_; using FfnLayerINT8<T>::inter_size_;
using FfnLayerINT8<T>::stream_; using FfnLayerINT8<T>::stream_;
using FfnLayerINT8<T>::int8_mode_; using FfnLayerINT8<T>::int8_mode_;
using FfnLayerINT8<T>::sparse_; using FfnLayerINT8<T>::sparse_;
using FfnLayerINT8<T>::hidden_units_; using FfnLayerINT8<T>::hidden_units_;
void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) override; void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) override;
}; };
template<typename T> template<typename T>
class ReluFfnLayerINT8: public FfnLayerINT8<T> { class ReluFfnLayerINT8: public FfnLayerINT8<T> {
public: public:
ReluFfnLayerINT8(size_t max_batch_size, ReluFfnLayerINT8(size_t max_batch_size,
size_t max_seq_len, size_t max_seq_len,
size_t head_num, size_t head_num,
size_t size_per_head, size_t size_per_head,
size_t inter_size, size_t inter_size,
int int8_mode, int int8_mode,
cudaStream_t stream, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, cublasMMWrapper* cublas_wrapper,
IAllocator* allocator, IAllocator* allocator,
bool is_free_buffer_after_forward); bool is_free_buffer_after_forward);
ReluFfnLayerINT8(ReluFfnLayerINT8<T> const& ffn_layer); ReluFfnLayerINT8(ReluFfnLayerINT8<T> const& ffn_layer);
~ReluFfnLayerINT8() = default; ~ReluFfnLayerINT8() = default;
private: private:
using FfnLayerINT8<T>::inter_int_buf_; using FfnLayerINT8<T>::inter_int_buf_;
using FfnLayerINT8<T>::inter_buf_; using FfnLayerINT8<T>::inter_buf_;
using FfnLayerINT8<T>::inter_size_; using FfnLayerINT8<T>::inter_size_;
using FfnLayerINT8<T>::stream_; using FfnLayerINT8<T>::stream_;
using FfnLayerINT8<T>::int8_mode_; using FfnLayerINT8<T>::int8_mode_;
using FfnLayerINT8<T>::hidden_units_; using FfnLayerINT8<T>::hidden_units_;
void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) override; void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) override;
}; };
} // namespace fastertransformer } // namespace fastertransformer
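The header above follows a template-method design: `FfnLayerINT8` owns the buffers and the two GEMMs, and each subclass contributes only the bias + activation step via the pure virtual `invokeAddBiasActivation`. A minimal standalone sketch of that pattern, with simplified names (`FfnBase`, `GeluFfn`, `ReluFfn`) and the GEMMs stubbed out:

```cpp
// Standalone sketch of the base/derived split used by FfnLayerINT8; not the real classes.
#include <cmath>
#include <vector>

class FfnBase {
public:
    virtual ~FfnBase() = default;

    // Common pipeline: (1) inner GEMM, (2) subclass-specific bias + activation, (3) outer GEMM.
    void forward(std::vector<float>& x, const std::vector<float>& bias)
    {
        runGemm(x);                  // placeholder for the first (INT8) GEMM
        addBiasActivation(x, bias);  // the only step that differs between GELU and ReLU
        runGemm(x);                  // placeholder for the second GEMM
    }

protected:
    virtual void addBiasActivation(std::vector<float>& x, const std::vector<float>& bias) = 0;

private:
    void runGemm(std::vector<float>& /*x*/) { /* GEMM omitted in this sketch */ }
};

class GeluFfn: public FfnBase {
    void addBiasActivation(std::vector<float>& x, const std::vector<float>& bias) override
    {
        // assumes bias is non-empty and x.size() is a multiple of bias.size()
        for (size_t i = 0; i < x.size(); ++i) {
            const float v = x[i] + bias[i % bias.size()];
            x[i] = 0.5f * v * (1.f + std::tanh(0.7978845608f * (v + 0.044715f * v * v * v)));
        }
    }
};

class ReluFfn: public FfnBase {
    void addBiasActivation(std::vector<float>& x, const std::vector<float>& bias) override
    {
        for (size_t i = 0; i < x.size(); ++i) {
            const float v = x[i] + bias[i % bias.size()];
            x[i] = v > 0.f ? v : 0.f;
        }
    }
};
```

Keeping the activation as the only virtual hook lets the GELU and ReLU variants share buffer management and the sparse/INT8 code paths without duplicating them.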
...@@ -68,12 +68,13 @@ AttentionType getAttentionType(size_t size_per_head, ...@@ -68,12 +68,13 @@ AttentionType getAttentionType(size_t size_per_head,
} }
// GPT and its variants // GPT and its variants
else { else {
// FMHA_ENABLE only affects gpt-style models (causal-mask) // FMHA_ENABLE only affects gpt-style models (causal-mask)
char * fused_qkv = std::getenv("FMHA_ENABLE"); char* fused_qkv = std::getenv("FMHA_ENABLE");
if (fused_qkv != nullptr && std::string(fused_qkv) == "ON") { if (fused_qkv != nullptr && std::string(fused_qkv) == "ON") {
if ((sm == kSM_70 || sm == kSM_72 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89) if ((sm == kSM_70 || sm == kSM_72 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89)
&& (size_per_head == 32 || size_per_head == 40 || size_per_head == 64 || size_per_head == 80 && (size_per_head == 32 || size_per_head == 40 || size_per_head == 64 || size_per_head == 80
|| size_per_head == 128 || size_per_head == 144 || size_per_head == 160 || size_per_head == 256)) { || size_per_head == 128 || size_per_head == 144 || size_per_head == 160
|| size_per_head == 256)) {
return remove_padding ? AttentionType::FUSED_MHA : AttentionType::UNFUSED_PADDED_MHA; return remove_padding ? AttentionType::FUSED_MHA : AttentionType::UNFUSED_PADDED_MHA;
} }
} }
......
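For reference, the fused path in the hunk above is opt-in: `getAttentionType` only considers `FUSED_MHA` when the `FMHA_ENABLE` environment variable is exactly `"ON"` and both the SM version and the head size are on the allow-list. A small sketch of how a host program could set the flag before constructing the model; `setenv` is the POSIX call, and launching the process with `FMHA_ENABLE=ON` in the environment is equivalent:

```cpp
// Sketch: opt in to the fused MHA path gated by FMHA_ENABLE (values here are illustrative).
#include <cstdio>
#include <cstdlib>
#include <string>

int main()
{
    // Must be set before the attention layer queries it.
    setenv("FMHA_ENABLE", "ON", /*overwrite=*/1);

    const char* flag = std::getenv("FMHA_ENABLE");
    std::printf("FMHA requested: %s\n", (flag && std::string(flag) == "ON") ? "yes" : "no");
    return 0;
}
```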
...@@ -13,4 +13,3 @@ ...@@ -13,4 +13,3 @@
# limitations under the License. # limitations under the License.
cmake_minimum_required(VERSION 3.8) cmake_minimum_required(VERSION 3.8)
/* /*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#pragma once #pragma once
#include "src/fastertransformer/layers/attention_layers/AttentionWeight.h" #include "src/fastertransformer/layers/attention_layers/AttentionWeight.h"
#include "src/fastertransformer/utils/ScaleList.h" #include "src/fastertransformer/utils/ScaleList.h"
namespace fastertransformer { namespace fastertransformer {
template<typename T1, typename T2> template<typename T1, typename T2>
struct AttentionFP8Weight: public AttentionWeight<T1, T2> { struct AttentionFP8Weight: public AttentionWeight<T1, T2> {
const float* qk_scale; const float* qk_scale;
const float* qk_scale_inv; const float* qk_scale_inv;
float* qk_h_scale; float* qk_h_scale;
float* qk_h_scale_inv; float* qk_h_scale_inv;
float* identity_scale; float* identity_scale;
float* identity_h_scale; float* identity_h_scale;
}; };
} // namespace fastertransformer } // namespace fastertransformer
/* /*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#pragma once #pragma once
#include "src/fastertransformer/layers/attention_layers/AttentionWeight.h" #include "src/fastertransformer/layers/attention_layers/AttentionWeight.h"
#include "src/fastertransformer/utils/ScaleList.h" #include "src/fastertransformer/utils/ScaleList.h"
namespace fastertransformer { namespace fastertransformer {
template<typename T> template<typename T>
struct AttentionINT8Weight: AttentionWeight<T> { struct AttentionINT8Weight: AttentionWeight<T> {
ScaleList* scale_list_ptr; ScaleList* scale_list_ptr;
}; };
} // namespace fastertransformer } // namespace fastertransformer
...@@ -46,4 +46,4 @@ public: ...@@ -46,4 +46,4 @@ public:
} }
}; };
} // namespace fastertransformer } // namespace fastertransformer
\ No newline at end of file
...@@ -15,9 +15,9 @@ public: ...@@ -15,9 +15,9 @@ public:
pthread_barrier_init(&barrier_, nullptr, count); pthread_barrier_init(&barrier_, nullptr, count);
} }
Barrier(const Barrier&) = delete; Barrier(const Barrier&) = delete;
Barrier& operator=(const Barrier&) = delete; Barrier& operator=(const Barrier&) = delete;
Barrier(Barrier&&) noexcept = delete; Barrier(Barrier&&) noexcept = delete;
Barrier& operator=(Barrier&&) noexcept = delete; Barrier& operator=(Barrier&&) noexcept = delete;
void wait() void wait()
...@@ -34,4 +34,4 @@ private: ...@@ -34,4 +34,4 @@ private:
pthread_barrier_t barrier_{}; pthread_barrier_t barrier_{};
}; };
} // namespace fastertransformer } // namespace fastertransformer
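Usage sketch for the `Barrier` wrapper above, assuming the constructor takes the participant count and `wait()` maps to `pthread_barrier_wait` (as the `pthread_barrier_*` calls in the diff suggest). The stand-in class and worker lambda below are illustrative:

```cpp
// Four workers rendezvous at the barrier; none proceeds until all have arrived.
#include <pthread.h>
#include <cstdio>
#include <thread>
#include <vector>

// stand-in for fastertransformer::Barrier from the header above
class Barrier {
public:
    explicit Barrier(unsigned count) { pthread_barrier_init(&barrier_, nullptr, count); }
    ~Barrier() { pthread_barrier_destroy(&barrier_); }
    void wait() { pthread_barrier_wait(&barrier_); }

private:
    pthread_barrier_t barrier_{};
};

int main()
{
    constexpr int n_workers = 4;
    Barrier       barrier(n_workers);

    std::vector<std::thread> workers;
    for (int i = 0; i < n_workers; ++i) {
        workers.emplace_back([&barrier, i] {
            std::printf("worker %d: before barrier\n", i);
            barrier.wait();  // blocks until all n_workers threads reach this point
            std::printf("worker %d: after barrier\n", i);
        });
    }
    for (auto& w : workers) {
        w.join();
    }
    return 0;
}
```

Link with `-lpthread`. Deleting copy and move in the real class is the right call, presumably because `pthread_barrier_t` must not be relocated while threads are waiting on it.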
\ No newline at end of file
...@@ -4,7 +4,7 @@ cmake_minimum_required(VERSION 3.8) ...@@ -4,7 +4,7 @@ cmake_minimum_required(VERSION 3.8)
add_subdirectory(fused_multi_head_attention) add_subdirectory(fused_multi_head_attention)
add_library(Llama STATIC add_library(Llama STATIC
LlamaV2.cc LlamaV2.cc
LlamaBatch.cc LlamaBatch.cc
LlamaCacheManager.cc LlamaCacheManager.cc
......
...@@ -19,11 +19,11 @@ template<typename T> ...@@ -19,11 +19,11 @@ template<typename T>
void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_reqs, void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_reqs,
std::vector<std::shared_ptr<Request>>& infer_reqs) std::vector<std::shared_ptr<Request>>& infer_reqs)
{ {
std::unordered_map<uint64_t, int> occurance; std::unordered_map<uint64_t, int> occurrence;
auto count_occurance = [&occurance](const std::vector<std::shared_ptr<Request>>& rs) { auto count_occurrence = [&occurrence](const std::vector<std::shared_ptr<Request>>& rs) {
for (const auto& r : rs) { for (const auto& r : rs) {
++occurance[r->id]; ++occurrence[r->id];
} }
}; };
...@@ -33,13 +33,13 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r ...@@ -33,13 +33,13 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
req.reset(); req.reset();
}; };
auto handle_conflict_or_invalid = [this, &occurance, &invalidate](std::vector<std::shared_ptr<Request>>& rs, auto handle_conflict_or_invalid = [this, &occurrence, &invalidate](std::vector<std::shared_ptr<Request>>& rs,
const char* type) { const char* type) {
for (auto& r : rs) { for (auto& r : rs) {
if (r) { if (r) {
int ec = 0; int ec = 0;
if (occurance[r->id] != 1) { if (occurrence[r->id] != 1) {
ec = Request::kConflict; ec = Request::kConflict;
} }
else if (r->start_flag && r->stop_flag) { else if (r->start_flag && r->stop_flag) {
...@@ -66,8 +66,8 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r ...@@ -66,8 +66,8 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
rs.resize(count); rs.resize(count);
}; };
count_occurance(stop_reqs); count_occurrence(stop_reqs);
count_occurance(infer_reqs); count_occurrence(infer_reqs);
if (!stop_reqs.empty()) { if (!stop_reqs.empty()) {
handle_conflict_or_invalid(stop_reqs, "stop"); handle_conflict_or_invalid(stop_reqs, "stop");
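The duplicate-ID handling in `verifyRequests` boils down to: count every request ID across both queues, then treat any ID seen more than once (or a request with both start and stop set) as invalid. A standalone sketch of the counting half, with simplified stand-ins for `Request` and the error code:

```cpp
// Reject requests whose ID appears more than once across the stop and infer queues.
#include <algorithm>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

struct Request {
    uint64_t id;
    int      error_code = 0;
    static constexpr int kConflict = 2;  // illustrative error code
};

void reject_conflicts(std::vector<std::shared_ptr<Request>>& stop_reqs,
                      std::vector<std::shared_ptr<Request>>& infer_reqs)
{
    std::unordered_map<uint64_t, int> occurrence;
    auto count = [&occurrence](const std::vector<std::shared_ptr<Request>>& rs) {
        for (const auto& r : rs) {
            ++occurrence[r->id];
        }
    };
    count(stop_reqs);
    count(infer_reqs);

    auto invalidate = [&occurrence](std::vector<std::shared_ptr<Request>>& rs) {
        for (auto& r : rs) {
            if (occurrence[r->id] != 1) {  // same ID submitted twice -> conflict
                r->error_code = Request::kConflict;
                r.reset();                 // drop the conflicting request
            }
        }
        rs.erase(std::remove(rs.begin(), rs.end(), nullptr), rs.end());
    };
    invalidate(stop_reqs);
    invalidate(infer_reqs);
}
```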
...@@ -129,7 +129,7 @@ void LlamaBatch<T>::handleStopRequests(const std::vector<std::shared_ptr<Request ...@@ -129,7 +129,7 @@ void LlamaBatch<T>::handleStopRequests(const std::vector<std::shared_ptr<Request
ec = 0; ec = 0;
llama_->kv_cache_mgr_->erase(r->id); llama_->kv_cache_mgr_->erase(r->id);
} }
// clear output buffers (prevent leaking conversations) if request is successfull // clear output buffers (prevent leaking conversations) if request is successful
if (ec == 0) { if (ec == 0) {
auto& output_ids = r->outputs[rank_].at("output_ids"); auto& output_ids = r->outputs[rank_].at("output_ids");
auto& sequence_length = r->outputs[rank_].at("sequence_length"); auto& sequence_length = r->outputs[rank_].at("sequence_length");
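A sketch of the "clear output buffers" step above, under the assumption that `output_ids` and `sequence_length` are plain device buffers of known size; the real code goes through the `Tensor` objects stored on the request. Zeroing on the stream keeps the clear ordered with the rest of the batch's work:

```cpp
// Zero a stopped request's output buffers so no tokens leak into the next request
// that reuses the same slot. Buffer layout here is an assumption for illustration.
#include <cuda_runtime.h>

void clear_outputs(int* d_output_ids, int* d_sequence_length,
                   size_t max_output_len, size_t batch, cudaStream_t stream)
{
    cudaMemsetAsync(d_output_ids, 0, sizeof(int) * max_output_len * batch, stream);
    cudaMemsetAsync(d_sequence_length, 0, sizeof(int) * batch, stream);
}
```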
...@@ -407,7 +407,7 @@ void LlamaBatch<T>::initializeGeneration() ...@@ -407,7 +407,7 @@ void LlamaBatch<T>::initializeGeneration()
check_cuda_error( check_cuda_error(
cudaMemcpyAsync(sequence_lengths_, context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_)); cudaMemcpyAsync(sequence_lengths_, context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
// `sequence_lengths_` will be increased by dynamic decode // `sequence_lengths_` will be increased by dynamic decode
// note that in decoder and in output "sequence length" has differnt semantic // note that in decoder and in output "sequence length" has different semantic
// - in decoder it means length of sequence that has kv cache already computed // - in decoder it means length of sequence that has kv cache already computed
// - in output it means length of all tokens (the last generated token does not have k/v cache computed yet) // - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
invokePlusScalar(sequence_lengths_, -1, batch_size_, stream_); invokePlusScalar(sequence_lengths_, -1, batch_size_, stream_);
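The `-1` adjustment above converts the output-side notion of sequence length (all tokens, including the last generated token) into the decoder-side notion (tokens whose k/v cache already exists). A generic plus-scalar kernel of the kind `invokePlusScalar` wraps; the kernel name and launch configuration are assumptions for illustration:

```cpp
// Add a scalar (here typically -1) to every entry of an int array on the given stream.
#include <cuda_runtime.h>

__global__ void plus_scalar_kernel(int* data, int delta, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        data[i] += delta;
    }
}

void plus_scalar(int* data, int delta, int n, cudaStream_t stream)
{
    const int block = 256;
    const int grid  = (n + block - 1) / block;
    plus_scalar_kernel<<<grid, block, 0, stream>>>(data, delta, n);
}
```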
...@@ -1039,4 +1039,4 @@ void LlamaBatch<T>::finishRequest(int index, bool force_end) ...@@ -1039,4 +1039,4 @@ void LlamaBatch<T>::finishRequest(int index, bool force_end)
template class LlamaBatch<half>; template class LlamaBatch<half>;
template class LlamaBatch<float>; template class LlamaBatch<float>;
} // namespace fastertransformer } // namespace fastertransformer
\ No newline at end of file
...@@ -122,7 +122,7 @@ private: ...@@ -122,7 +122,7 @@ private:
void* topk_curandstate_buf_{}; void* topk_curandstate_buf_{};
void* topp_curandstate_buf_{}; void* topp_curandstate_buf_{};
// hard limits for persistant buffers // hard limits for persistent buffers
static constexpr int kMaxStopBadWordsLen = 32; static constexpr int kMaxStopBadWordsLen = 32;
using CachedSeq = LlamaCacheManager::Sequence; using CachedSeq = LlamaCacheManager::Sequence;
...@@ -150,4 +150,4 @@ private: ...@@ -150,4 +150,4 @@ private:
IAllocator* allocator_{}; IAllocator* allocator_{};
}; };
} // namespace fastertransformer } // namespace fastertransformer
\ No newline at end of file