Commit 374b85bd authored by yuguo
parents 1b971e27 11864d3d
......@@ -384,6 +384,30 @@ __launch_bounds__(1024) __global__
}
}
template <typename OutputType>
__launch_bounds__(1024) __global__
void tensorwise_int8_bias_gradient_kernel(OutputType* dst, const int8_t* src, float* scale, int M, int N) {
__shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
const int j = blockIdx.x * blockDim.x + threadIdx.x;
float grad_sum = 0.f;
float tensorwise_scale = scale[0];
if (j < N) {
for (int i = threadIdx.y; i < M; i += blockDim.y) {
grad_sum += static_cast<float>(src[i * N + j]) * tensorwise_scale;
}
}
g_shared[threadIdx.y][threadIdx.x] = grad_sum;
__syncthreads();
float sum = g_shared[threadIdx.x][threadIdx.y];
sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
if (threadIdx.x == 0) {
const int j = blockIdx.x * blockDim.x + threadIdx.y;
if (j < N) {
dst[j] = static_cast<OutputType>(sum);
}
}
}
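// For clarity, the kernel above is a tiled column-wise reduction: it computes
// dst[j] = OutputType(scale[0] * sum_i src[i * N + j]) for every column j of the
// M x N int8 gradient matrix. A plain host-side reference of the same math is
// sketched below; it is illustrative only, and this helper name is not part of the file.
template <typename OutputType>
void tensorwise_int8_bias_gradient_reference(OutputType* dst, const int8_t* src,
                                             const float* scale, int M, int N) {
  const float tensorwise_scale = scale[0];  // single scale for the whole tensor
  for (int j = 0; j < N; ++j) {
    float grad_sum = 0.f;
    for (int i = 0; i < M; ++i) {
      grad_sum += static_cast<float>(src[i * N + j]) * tensorwise_scale;
    }
    dst[j] = static_cast<OutputType>(grad_sum);  // bias gradient for column j
  }
}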
template <typename Tin>
void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc,
hipStream_t stream) {
......@@ -403,6 +427,19 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
<<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, m, n);
}
template <typename Tout>
void tensorwise_int8_bias_gradient_kernelLauncher(const int8_t* in, Tout* out, float* scale, int m, int n, hipStream_t stream) {
NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(Tout), stream));
int B = (n - 1) / kColwiseReduceTileSize + 1;
tensorwise_int8_bias_gradient_kernel<Tout>
<<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, scale, m, n);
}
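// A minimal usage sketch with hypothetical device buffers; the real call site further
// below dispatches on the bias dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY:
//
//   // grad_output: k x n int8 gradient (GEMM input B), d_bias: n-element output buffer,
//   // d_scale_inv: single tensorwise inverse scale on the device.
//   detail::tensorwise_int8_bias_gradient_kernelLauncher<float>(
//       grad_output, d_bias, d_scale_inv, /*m=*/k, /*n=*/n, stream);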
} // namespace detail
transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t) {
......@@ -962,6 +999,20 @@ static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
}
}
transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
using namespace transformer_engine;
switch (t) {
case HIP_R_16F:
return DType::kFloat16;
case HIP_R_32F:
return DType::kFloat32;
case HIP_R_16BF:
return DType::kBFloat16;
default:
NVTE_ERROR("Invalid type");
}
}
void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
......@@ -1090,9 +1141,6 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
(void*)&B_scale_inverse_float,
sizeof(void*)));
if (bias) {
NVTE_CHECK(false, "tensorwise_int8 not surpport bias!");
}
}
if (bias && gelu) {
......@@ -1109,14 +1157,34 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
} else if (bias) {
if (grad) {
// grad output is always input B
epilogue = HIPBLASLT_EPILOGUE_BGRADB;
if (tensorwise_int8) {
if (grad) {
int batch_size = k;
int output_dim = n;
DType te_bias_dtype = get_transformer_engine_dtype_from_hipblaslt_dtype(bias_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
te_bias_dtype, BType,
detail::tensorwise_int8_bias_gradient_kernelLauncher<BType>(
reinterpret_cast<const int8_t*>(B), reinterpret_cast<BType*>(bias_ptr), B_scale_inverse_float, batch_size,
output_dim, stream););
} else {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
epilogue = HIPBLASLT_EPILOGUE_BIAS;
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
}
} else {
epilogue = HIPBLASLT_EPILOGUE_BIAS;
if (grad) {
// grad output is always input B
epilogue = HIPBLASLT_EPILOGUE_BGRADB;
} else {
epilogue = HIPBLASLT_EPILOGUE_BIAS;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
} else if (gelu) {
if (grad) {
epilogue = HIPBLASLT_EPILOGUE_DGELU;
......
......@@ -185,7 +185,6 @@ def general_gemm(
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation"
assert ub_type is None, "User buffer type not supported with int8 simulation"
assert extra_output is None, "Extra output not supported with int8 simulation"
......@@ -473,13 +472,13 @@ def general_grouped_gemm(
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise:
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm currently only supports equal token padding (same as batch gemm)."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
if layout == "TN":
assert out_dtype is torch.bfloat16
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
qx_data_list, scales_x_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
w_data_list, scales_w_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
......@@ -501,7 +500,7 @@ def general_grouped_gemm(
num_gemms,
None,
TE_DType[out_dtype],
bias,
bias[0],
bias_dtype,
gelu,
gelu_input[0],
......@@ -514,7 +513,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
if layout == "NN":
assert out_dtype is torch.bfloat16
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
qdout_data_list, scales_dout_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
w_data_list, scales_w_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
......@@ -536,7 +535,7 @@ def general_grouped_gemm(
num_gemms,
None,
TE_DType[out_dtype],
bias,
bias[0],
bias_dtype,
gelu,
gelu_input[0],
......@@ -549,7 +548,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
elif layout == "NT":
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32
qdout_data_list, scales_dout_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
qx_data_list, scales_x_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
......@@ -572,7 +571,7 @@ def general_grouped_gemm(
num_gemms,
None,
TE_DType[out_dtype],
bias,
bias[0],
bias_dtype,
gelu,
gelu_input[0],
......@@ -591,10 +590,10 @@ def general_grouped_gemm(
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm currently only supports equal token padding (same as batch gemm)."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
if layout == "TN":
assert out_dtype is torch.bfloat16
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
qx_data_list = []
w_data_list = []
scales_x_list = []
......@@ -642,7 +641,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
elif layout == "NN":
assert out_dtype is torch.bfloat16
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
qdout_data_list = []
w_data_list = []
scales_dout_list = []
......@@ -690,7 +689,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
elif layout == "NT":
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32
qdout_data_list = []
qx_data_list = []
scales_dout_list = []
......@@ -730,7 +729,7 @@ def general_grouped_gemm(
use_split_accumulator,
)[0]
if out_dtype is torch.bfloat16:
if TE_DType_To_Torch[out_dtype] is torch.bfloat16:
if accumulate:
for i in range(num_gemms):
channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
......
......@@ -49,8 +49,10 @@ std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool tran
const size_t B1 = B_shape.data[B_shape.ndim - 1];
// Check matrix dims
NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(",
A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")");
if (transa || transb) {
NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(",
A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")");
}
// Construct output dims
std::vector<size_t> ret;
......