"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "2a7cb008acbf7a1942a6270b014bdf86886ed6f1"
Commit 11864d3d authored by yuguo's avatar yuguo
Browse files

[DCU] tensorwise int8 gemm support bias

parent 32edae18
......@@ -384,6 +384,30 @@ __launch_bounds__(1024) __global__
}
}
// Column-wise bias-gradient reduction for a tensorwise-quantized int8 gradient:
//   dst[j] = OutputType( sum_i( float(src[i * N + j]) * scale[0] ) ),  j in [0, N)
// Expected launch: grid.x = ceil(N / kColwiseReduceTileSize) blocks of
// (kColwiseReduceTileSize x kColwiseReduceTileSize) threads (see the launcher).
template <typename OutputType>
__launch_bounds__(1024) __global__
void tensorwise_int8_bias_gradient_kernel(OutputType* dst, const int8_t* src, float* scale, int M, int N) {
// Tile of per-thread partial sums; read back transposed so one warp row can
// finish the reduction for each column.
__shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
// Column this thread accumulates during the strided phase.
const int j = blockIdx.x * blockDim.x + threadIdx.x;
float grad_sum = 0.f;
// Single tensorwise dequantization factor; every thread reads the same element.
float tensorwise_scale = scale[0];
if (j < N) {
// Walk rows i = threadIdx.y, threadIdx.y + blockDim.y, ... of column j,
// dequantizing each int8 value before summing.
for (int i = threadIdx.y; i < M; i += blockDim.y) {
grad_sum += static_cast<float>(src[i * N + j]) * tensorwise_scale;
}
}
// Out-of-range columns (j >= N) contribute 0, so the whole tile is defined.
g_shared[threadIdx.y][threadIdx.x] = grad_sum;
__syncthreads();
// Transposed read: lanes that differ only in threadIdx.x now hold the partial
// sums of a single column, making them reducible with warp shuffles.
float sum = g_shared[threadIdx.x][threadIdx.y];
// NOTE(review): assumes WarpReduceSum starting at offset tile/2 reduces across
// all kColwiseReduceTileSize lanes, and that the tile width matches the warp
// grouping on this architecture -- confirm against WarpReduceSum's definition.
sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
if (threadIdx.x == 0) {
// Lane 0 of each row writes one column; threadIdx.y now selects the column.
const int j = blockIdx.x * blockDim.x + threadIdx.y;
if (j < N) {
dst[j] = static_cast<OutputType>(sum);
}
}
}
template <typename Tin>
void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc,
hipStream_t stream) {
......@@ -403,6 +427,19 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
<<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, m, n);
}
// Launches tensorwise_int8_bias_gradient_kernel to reduce an (m x n) int8
// gradient tensor column-wise into an n-element bias gradient of type Tout.
//   in     - device pointer to the int8 gradient, row-major (m rows, n cols)
//   out    - device pointer receiving n dequantized column sums
//   scale  - device pointer to the single tensorwise dequantization factor
//   stream - HIP stream the memset and kernel launch are ordered on
// One (tile x tile) thread block covers kColwiseReduceTileSize consecutive
// columns, looping over all m rows internally, so a ceil-div over n columns
// gives the 1-D grid.
template <typename Tout>
void tensorwise_int8_bias_gradient_kernelLauncher(const int8_t* in, Tout* out, float* scale, int m, int n, hipStream_t stream) {
  // Defensive zero-fill: the kernel writes every out[j] for j < n
  // unconditionally, so this only matters if the launch itself fails.
  NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(Tout), stream));
  // Ceil-div: number of column tiles needed to cover all n columns.
  int B = (n - 1) / kColwiseReduceTileSize + 1;
  tensorwise_int8_bias_gradient_kernel<Tout>
      <<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, scale, m, n);
}
} // namespace detail
transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t) {
......@@ -962,6 +999,20 @@ static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
}
}
// Maps a hipBLASLt/HIP data type onto the corresponding Transformer Engine
// dtype. Only the floating-point types usable as GEMM bias dtypes are
// recognized; any other input aborts via NVTE_ERROR.
transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
  using namespace transformer_engine;
  if (t == HIP_R_16F) {
    return DType::kFloat16;
  }
  if (t == HIP_R_16BF) {
    return DType::kBFloat16;
  }
  if (t == HIP_R_32F) {
    return DType::kFloat32;
  }
  NVTE_ERROR("Invalid type");
}
void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
......@@ -1090,9 +1141,6 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
(void*)&B_scale_inverse_float,
sizeof(void*)));
if (bias) {
NVTE_CHECK(false, "tensorwise_int8 not surpport bias!");
}
}
if (bias && gelu) {
......@@ -1109,14 +1157,34 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
} else if (bias) {
if (grad) {
// grad output is always input B
epilogue = HIPBLASLT_EPILOGUE_BGRADB;
if (tensorwise_int8) {
if (grad) {
int batch_size = k;
int output_dim = n;
DType te_bias_dtype = get_transformer_engine_dtype_from_hipblaslt_dtype(bias_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
te_bias_dtype, BType,·
detail::tensorwise_int8_bias_gradient_kernelLauncher<BType>(
reinterpret_cast<const int8_t*>(B), reinterpret_cast<BType*>(bias_ptr), B_scale_inverse_float, batch_size,
output_dim, stream););
} else {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
epilogue = HIPBLASLT_EPILOGUE_BIAS;
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
}
} else {
epilogue = HIPBLASLT_EPILOGUE_BIAS;
if (grad) {
// grad output is always input B
epilogue = HIPBLASLT_EPILOGUE_BGRADB;
} else {
epilogue = HIPBLASLT_EPILOGUE_BIAS;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
} else if (gelu) {
if (grad) {
epilogue = HIPBLASLT_EPILOGUE_DGELU;
......
......@@ -185,7 +185,6 @@ def general_gemm(
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation"
assert ub_type is None, "User buffer type not supported with int8 simulation"
assert extra_output is None, "Extra output not supported with int8 simulation"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment