"vscode:/vscode.git/clone" did not exist on "09ff7f106a1e5ad77f6c2941382d5ba4fd5a0879"
Commit cb2fe806 authored by wenjh

Add bias fwd/bwd to grouped GEMM


Signed-off-by: wenjh <wenjh@sugon.com>
parent f796eb80
@@ -348,7 +348,7 @@ else()
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
util/cast.cu
cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
......
@@ -705,6 +705,7 @@ struct TypeInfo {
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt8: \
case DType::kFloat8E5M2: \
case DType::kFloat8E4M3: { \
NVTE_ERROR("FP8 type not instantiated for input."); \
@@ -712,10 +713,6 @@ struct TypeInfo {
case DType::kFloat4E2M1: { \
NVTE_ERROR("FP4 type not instantiated for input."); \
} break; \
case DType::kInt8: { \
using type = int8; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
@@ -735,14 +732,11 @@ struct TypeInfo {
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt8: \
case DType::kFloat8E5M2: \
case DType::kFloat8E4M3: { \
NVTE_ERROR("FP8 type not instantiated for input."); \
} break; \
case DType::kInt8: { \
using type = int8; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
......
@@ -1425,13 +1425,14 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
n.push_back(B0);
}
}
bool use_bias = (biasTensor[0]->data.dptr != nullptr);
Tensor *wspace = convertNVTETensorCheck(workspace[0]);
if ((biasTensor[0]->data.dptr != nullptr) || (outputGelu[0]->data.dptr != nullptr)) {
NVTE_ERROR("MOE nvte_grouped_gemm not surpport bias or gelu.");
if (outputGelu[0]->data.dptr != nullptr) {
NVTE_ERROR("MOE nvte_grouped_gemm not surpport gelu.");
}
hipblaslt_groupedgemm(inputA, inputB, outputD, m, n, k, b,
hipblaslt_groupedgemm(inputA, inputB, outputD, biasTensor, use_bias, grad, m, n, k, b,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
wspace->data.dptr, wspace->data.shape[0],
......
@@ -362,9 +362,9 @@ __inline__ __device__ T WarpReduceSum(T val, int max = 32) {
return val;
}
template <typename InputType>
template <typename InputType, typename OutputType>
__launch_bounds__(1024) __global__
void bias_gradient_kernel_v2(float* dst, const InputType* src, int M, int N) {
void bias_gradient_kernel_v2(OutputType* dst, const InputType* src, int M, int N) {
__shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
const int j = blockIdx.x * blockDim.x + threadIdx.x;
float grad_sum = 0.f;
@@ -380,7 +380,7 @@ __launch_bounds__(1024) __global__
if (threadIdx.x == 0) {
const int j = blockIdx.x * blockDim.x + threadIdx.y;
if (j < N) {
dst[j] = static_cast<float>(sum);
dst[j] = static_cast<OutputType>(sum);
}
}
}
@@ -409,8 +409,8 @@ __launch_bounds__(1024) __global__
}
}
template <typename Tin>
void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc,
template <typename Tin, typename Tout>
void bias_gradient_kernelLauncher(const Tin* in, Tout* out, int m, int n, bool stream_order_alloc,
hipStream_t stream) {
dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024;
@@ -418,13 +418,13 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
block.x = THREADS_PER_BLOCK;
grid.x = BLOCKS_PER_COL * n;
if (!stream_order_alloc) {
NVTE_CHECK_CUDA(hipMemset(out, 0, n * sizeof(float)));
NVTE_CHECK_CUDA(hipMemset(out, 0, n * sizeof(Tout)));
} else {
NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(float), stream));
NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(Tout), stream));
}
// hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
int B = (n - 1) / kColwiseReduceTileSize + 1;
bias_gradient_kernel_v2<Tin>
bias_gradient_kernel_v2<Tin, Tout>
<<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, m, n);
}
@@ -893,7 +893,7 @@ static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
}
static void DestroyHipBlasLtHandle(hipblasLtHandle_t handle) {
if(handle != nullptr)
if(handle != nullptr) {
NVTE_CHECK_HIPBLASLT(hipblasLtDestroy(handle));
}
}
@@ -1391,7 +1391,7 @@ struct HipBlasltUserArgsCache
{
HipBlasltUserArgsCache() {}
HipBlasltUserArgsCache(const HipBlasltUserArgsCache&) = delete;
HipBlasltUserArgsBuffer& operator=(const HipBlasltUserArgsBuffer&) = delete;
HipBlasltUserArgsCache& operator=(const HipBlasltUserArgsCache&) = delete;
HipBlasltUserArgsBuffer& getBuffer(hipStream_t stream, size_t size, bool host)
{
std::unordered_map<size_t, HipBlasltUserArgsBuffer>& buffers = host ? host_buffers_: device_buffers_;
@@ -1425,9 +1425,8 @@ struct HipBlasltUserArgsCacheManager {
std::vector<HipBlasltUserArgsCache> caches_;
};
void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
std::vector<Tensor*>& outputD, std::vector<const Tensor*>& bias, bool use_bias, bool grad, std::vector<int64_t>& m,
std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b,
hipblasOperation_t transa, hipblasOperation_t transb, void* workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
@@ -1467,6 +1466,13 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
std::vector<hipblaslt_ext::GemmEpilogue> epilogue{hipblaslt_ext::GemmEpilogue()};
if(use_bias && !grad)
{
const hipDataType bias_type = get_hipblaslt_dtype(bias[0]->data.dtype);
NVTE_CHECK(bias_type == HIP_R_32F || bias_type == HIP_R_16BF);
epilogue[0].mode = HIPBLASLT_EPILOGUE_BIAS;
epilogue[0].bias_data_type = bias_type;
}
std::vector<hipblaslt_ext::GemmInputs> inputs(m.size());
for (int i = 0; i < m.size(); i++) {
assert(m[i] != 0);
@@ -1477,6 +1483,7 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
inputs[i].b = inputB[i]->data.dptr;
inputs[i].c = outputD[i]->data.dptr;
inputs[i].d = outputD[i]->data.dptr;
inputs[i].bias = bias[i]->data.dptr;
inputs[i].alpha = use_int8 ? static_cast<void*>(&int_one) : static_cast<void*>(&one);
inputs[i].beta = use_int8 ? static_cast<void*>(&int_beta) : static_cast<void*>(&beta);
}
@@ -1512,6 +1519,26 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
NVTE_CHECK_HIPBLASLT(groupedgemm.run(device_args, stream));
device_user_args.setStream(stream);
NVTE_CHECK_CUDA(hipEventRecord(device_event, stream));
if(use_bias && grad)
{
DType input_type = inputB[0]->data.dtype;
DType bias_type = bias[0]->data.dtype;
NVTE_CHECK(bias_type == DType::kFloat32 || bias_type == DType::kFloat16 || bias_type == DType::kBFloat16);
for (int i = 0; i < m.size(); ++i) {
void* input_ptr = inputB[i]->data.dptr;
void* bias_ptr = bias[i]->data.dptr;
int batch_size = static_cast<int>(k[i]);
int output_dim = static_cast<int>(n[i]);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
input_type, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
bias_type, OType,
detail::bias_gradient_kernelLauncher<IType, OType>(
reinterpret_cast<const IType*>(input_ptr), reinterpret_cast<OType*>(bias_ptr), batch_size,
output_dim, true, stream);));
}
}
}
#endif //USE_HIPBLASLT
@@ -1738,7 +1765,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output_dtype, OType,
detail::bias_gradient_kernelLauncher<OType>(
detail::bias_gradient_kernelLauncher<OType, float>(
reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(bias_tmp), batch_size,
input_dim, stream_order_alloc, stream););
@@ -1808,7 +1835,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
DType bias_dtype = get_transformer_engine_dtype(bias_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
input_dtype, IType,
detail::bias_gradient_kernelLauncher<IType>(
detail::bias_gradient_kernelLauncher<IType, float>(
reinterpret_cast<const IType*>(B), reinterpret_cast<float*>(bias_tmp), batch_size,
output_dim, stream_order_alloc, stream););
if (bias_type != rocblas_datatype_f32_r) {
......
@@ -71,6 +71,13 @@ constexpr bool is_supported_arch() {
}
}
#ifdef __HIP_PLATFORM_AMD__
#define __CUDA_ARCH_HAS_FEATURE__(FEATURE) \
((__CUDA_ARCH__ >= 100 && FEATURE == SM100_ALL) || \
(__CUDA_ARCH__ >= 101 && FEATURE == SM101_ALL) || \
(__CUDA_ARCH__ >= 120 && FEATURE == SM120_ALL))
#endif
#if CUDA_VERSION < 12090
#if __CUDA_ARCH_HAS_FEATURE__(SM90_ALL)
#define __CUDA_ARCH_SPECIFIC__ 900
@@ -90,6 +97,7 @@ constexpr bool is_supported_arch() {
#endif
#endif
#ifdef __CUDA_ARCH__
#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = __CUDA_ARCH__;
#else
@@ -246,14 +254,6 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
#ifdef __HIP_PLATFORM_AMD__
#define __CUDA_ARCH_HAS_FEATURE__(FEATURE) \
((__CUDA_ARCH__ >= 100 && FEATURE == SM100_ALL) || \
(__CUDA_ARCH__ >= 101 && FEATURE == SM101_ALL) || \
(__CUDA_ARCH__ >= 120 && FEATURE == SM120_ALL))
#endif
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
return (biased_exp == 0) ? 1
: __int_as_float((254 - biased_exp)
@@ -265,6 +265,9 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
}
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_DEVICE_ERROR("float_to_e8m0 is not supported on rocm platform.");
#else
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
uint16_t out;
@@ -296,6 +299,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
}
return exponent;
}
#endif
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
@@ -407,6 +411,8 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() {
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
template <typename T>
struct alignas(2 * sizeof(T)) FPx2 {
T x;
@@ -834,6 +840,8 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
}
#endif
} // namespace ptx
namespace {
......
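For reference, the bias handling added in hipblaslt_groupedgemm splits into two cases: in the forward pass (use_bias && !grad) the bias is fused into the grouped GEMM through the HIPBLASLT_EPILOGUE_BIAS epilogue, and in the backward pass (use_bias && grad) the bias gradient is computed per group by bias_gradient_kernel_v2 as a column-wise sum over the incoming gradient matrix. A minimal host-side C++ sketch of that column-wise reduction, for illustration only; bias_gradient_reference, grad, and dbias are placeholder names, not identifiers from this patch:

```cpp
#include <cstddef>
#include <vector>

// Reference bias-gradient reduction: dbias[j] = sum over rows i of grad[i][j],
// where grad is an M x N row-major matrix (M rows reduced into N bias entries).
std::vector<float> bias_gradient_reference(const std::vector<float>& grad,
                                           std::size_t M, std::size_t N) {
  std::vector<float> dbias(N, 0.0f);
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      dbias[j] += grad[i * N + j];
    }
  }
  return dbias;
}
```

The device kernel performs the same reduction with a shared-memory tile and warp reductions; the sketch above only pins down the arithmetic being computed.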