Commit 5cf21c3b authored by wenjh
Browse files

Add bias fwd/bwd support to grouped GEMM


Signed-off-by: wenjh <wenjh@sugon.com>
parent 86d5cd03
......@@ -1030,14 +1030,14 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
n.push_back(B0);
}
}
bool use_bias = biasTensor[0]->data.dptr != nullptr? true: false;
Tensor *wspace = convertNVTETensorCheck(workspace[0]);
if ((biasTensor[0]->data.dptr != nullptr) || (outputGelu[0]->data.dptr != nullptr)) {
NVTE_ERROR("MOE nvte_grouped_gemm not surpport bias or gelu.");
if (outputGelu[0]->data.dptr != nullptr) {
NVTE_ERROR("MOE nvte_grouped_gemm not surpport gelu.");
}
hipblaslt_goupedgemm(inputA, inputB, outputD, m, n, k, b,
hipblaslt_goupedgemm(inputA, inputB, outputD, biasTensor, use_bias, grad, m, n, k, b,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
wspace->data.dptr, wspace->data.shape[0],
......
......@@ -362,9 +362,9 @@ __inline__ __device__ T WarpReduceSum(T val, int max = 32) {
return val;
}
template <typename InputType>
template <typename InputType, typename OutputType>
__launch_bounds__(1024) __global__
void bias_gradient_kernel_v2(float* dst, const InputType* src, int M, int N) {
void bias_gradient_kernel_v2(OutputType* dst, const InputType* src, int M, int N) {
__shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
const int j = blockIdx.x * blockDim.x + threadIdx.x;
float grad_sum = 0.f;
......@@ -380,7 +380,7 @@ __launch_bounds__(1024) __global__
if (threadIdx.x == 0) {
const int j = blockIdx.x * blockDim.x + threadIdx.y;
if (j < N) {
dst[j] = static_cast<float>(sum);
dst[j] = static_cast<OutputType>(sum);
}
}
}
......@@ -409,8 +409,8 @@ __launch_bounds__(1024) __global__
}
}
template <typename Tin>
void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc,
template <typename Tin, typename Tout>
void bias_gradient_kernelLauncher(const Tin* in, Tout* out, int m, int n, bool stream_order_alloc,
hipStream_t stream) {
dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024;
......@@ -418,13 +418,13 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
block.x = THREADS_PER_BLOCK;
grid.x = BLOCKS_PER_COL * n;
if (!stream_order_alloc) {
NVTE_CHECK_CUDA(hipMemset(out, 0, n * sizeof(float)));
NVTE_CHECK_CUDA(hipMemset(out, 0, n * sizeof(Tout)));
} else {
NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(float), stream));
NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(Tout), stream));
}
// hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
int B = (n - 1) / kColwiseReduceTileSize + 1;
bias_gradient_kernel_v2<Tin>
bias_gradient_kernel_v2<Tin, Tout>
<<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, m, n);
}
......@@ -1426,7 +1426,7 @@ struct HipBlasltUserArgsCacheManager {
};
void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
std::vector<Tensor*>& outputD, std::vector<const Tensor*>& bias, bool use_bias, bool grad, std::vector<int64_t>& m,
std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b,
hipblasOperation_t transa, hipblasOperation_t transb, void* workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
......@@ -1466,6 +1466,13 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
std::vector<hipblaslt_ext::GemmEpilogue> epilogue{hipblaslt_ext::GemmEpilogue()};
if(use_bias && !grad)
{
const hipDataType bias_type = get_hipblaslt_dtype(bias[0]->data.dtype);
NVTE_CHECK(bias_type == HIP_R_32F || bias_type == HIP_R_16BF);
epilogue[0].mode = HIPBLASLT_EPILOGUE_BIAS;
epilogue[0].bias_data_type = bias_type;
}
std::vector<hipblaslt_ext::GemmInputs> inputs(m.size());
for (int i = 0; i < m.size(); i++) {
assert(m[i] != 0);
......@@ -1476,6 +1483,7 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
inputs[i].b = inputB[i]->data.dptr;
inputs[i].c = outputD[i]->data.dptr;
inputs[i].d = outputD[i]->data.dptr;
inputs[i].bias = bias[i]->data.dptr;
inputs[i].alpha = use_int8 ? static_cast<void*>(&int_one) : static_cast<void*>(&one);
inputs[i].beta = use_int8 ? static_cast<void*>(&int_beta) : static_cast<void*>(&beta);
}
......@@ -1511,6 +1519,25 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
NVTE_CHECK_HIPBLASLT(groupedgemm.run(device_args, stream));
device_user_args.setStream(stream);
NVTE_CHECK_CUDA(hipEventRecord(device_event, stream));
if(use_bias && grad)
{
DType input_type = inputB[0]->data.dtype;
DType bias_type = bias[0]->data.dtype;
NVTE_CHECK(bias_type == DType::kFloat32 || bias_type == DType::kFloat16 || bias_type == DType::kBFloat16);
for (int i = 0; i < m.size(); ++i) {
void* input_ptr = inputB[i]->data.dptr;
void* bias_ptr = bias[i]->data.dptr;
batch_size = k[i];
output_dim = n[i];
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
input_dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
bias_type, OType,
detail::bias_gradient_kernelLauncher<IType, OType>(
reinterpret_cast<const IType*>(input_ptr), reinterpret_cast<OType*>(bias_ptr), batch_size,
output_dim, true, stream);));
}
}
}
#endif //USE_HIPBLASLT
......@@ -1737,7 +1764,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output_dtype, OType,
detail::bias_gradient_kernelLauncher<OType>(
detail::bias_gradient_kernelLauncher<OType, float>(
reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(bias_tmp), batch_size,
input_dim, stream_order_alloc, stream););
......@@ -1807,7 +1834,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
DType bias_dtype = get_transformer_engine_dtype(bias_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
input_dtype, IType,
detail::bias_gradient_kernelLauncher<IType>(
detail::bias_gradient_kernelLauncher<IType, float>(
reinterpret_cast<const IType*>(B), reinterpret_cast<float*>(bias_tmp), batch_size,
output_dim, stream_order_alloc, stream););
if (bias_type != rocblas_datatype_f32_r) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment