"git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "20044cab7aa6e884e13460506b0e0b6a12722b5d"
Commit d6c32078 authored by yuguo's avatar yuguo
Browse files

[DCU] add TORCH_COMM_CU_NUMS and fix

parent 8eff19c9
......@@ -588,8 +588,43 @@ static cudaEvent_t cublas_event[num_streams];
// Warning: only call once per device!
// Creates the per-device compute streams and their cuBLAS events.
// On ROCm/HIP builds, when the TORCH_COMM_CU_NUMS environment variable is
// set (non-empty), each stream is created with a CU mask whose low
// comm_cu_nums bits are cleared, keeping those compute units free for
// communication kernels. On CUDA builds (and when the variable is unset on
// HIP) the streams are plain non-blocking, low-priority streams.
static void init_streams_and_events() {
#ifdef __HIP_PLATFORM_AMD__
  // Number of CUs to reserve for communication; getIntEnv presumably falls
  // back to the default (8) when the variable is unset -- TODO confirm
  // getIntEnv("name", default, min?) semantics against its definition.
  int comm_cu_nums = getIntEnv("TORCH_COMM_CU_NUMS", 8, 4);
  // Only these values are supported; each clears the low N bits of word 0.
  NVTE_CHECK(comm_cu_nums == 4 || comm_cu_nums == 8 ||
                 comm_cu_nums == 16 || comm_cu_nums == 32,
             "comm_cu_nums must be 4,8,16,32");
  // Words 1..3 stay fully enabled; word 0 masks off the reserved CUs.
  // (A 32-bit shift by 32 is UB, so handle 32 explicitly.)
  unsigned int cuMask[4] = {0xffffffffu, 0xffffffffu, 0xffffffffu, 0xffffffffu};
  cuMask[0] = (comm_cu_nums == 32) ? 0x00000000u
                                   : (0xffffffffu << comm_cu_nums);
  const unsigned int cuMaskSize = 4;
  // The CU mask is applied only when the env var is explicitly set; hoist
  // the check out of the loop since it cannot change between iterations.
  const char *env = std::getenv("TORCH_COMM_CU_NUMS");
  const bool use_cu_mask = (env != nullptr && env[0] != '\0');
#endif
  for (int i = 0; i < num_streams; i++) {
#ifdef __HIP_PLATFORM_AMD__
    if (use_cu_mask) {
      NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams[i], cuMaskSize, cuMask));
    } else {
      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
    }
#else
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
#endif
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
  }
}
......@@ -601,8 +636,43 @@ static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];
// Warning: only call once per device!
// Creates the per-device batch-GEMM streams and their cuBLAS events.
// Mirrors init_streams_and_events: on ROCm/HIP builds, a non-empty
// TORCH_COMM_CU_NUMS environment variable makes each stream carry a CU mask
// whose low comm_cu_nums bits are cleared, reserving those compute units for
// communication. CUDA builds (and unset-variable HIP builds) get plain
// non-blocking, low-priority streams.
static void init_streams_and_events_batchgemm() {
#ifdef __HIP_PLATFORM_AMD__
  // Number of CUs to reserve; getIntEnv presumably returns the default (8)
  // when the variable is unset -- TODO confirm getIntEnv semantics.
  int comm_cu_nums = getIntEnv("TORCH_COMM_CU_NUMS", 8, 4);
  // Only these values are supported; each clears the low N bits of word 0.
  NVTE_CHECK(comm_cu_nums == 4 || comm_cu_nums == 8 ||
                 comm_cu_nums == 16 || comm_cu_nums == 32,
             "comm_cu_nums must be 4,8,16,32");
  // Words 1..3 stay fully enabled; word 0 masks off the reserved CUs.
  // (A 32-bit shift by 32 is UB, so handle 32 explicitly.)
  unsigned int cuMask[4] = {0xffffffffu, 0xffffffffu, 0xffffffffu, 0xffffffffu};
  cuMask[0] = (comm_cu_nums == 32) ? 0x00000000u
                                   : (0xffffffffu << comm_cu_nums);
  const unsigned int cuMaskSize = 4;
  // Apply the CU mask only when the env var is explicitly set; hoisted out
  // of the loop since it cannot change between iterations.
  const char *env = std::getenv("TORCH_COMM_CU_NUMS");
  const bool use_cu_mask = (env != nullptr && env[0] != '\0');
#endif
  for (int i = 0; i < num_batchgemm_streams; i++) {
#ifdef __HIP_PLATFORM_AMD__
    if (use_cu_mask) {
      NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams_batchgemm[i], cuMaskSize, cuMask));
    } else {
      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
    }
#else
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
#endif
    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[i]));
  }
}
......
......@@ -201,7 +201,6 @@ def general_grouped_gemm(
if int8_simulation_fp8 and (isinstance(A[0], Float8BlockwiseQTensorBase) or isinstance(B[0], Float8BlockwiseQTensorBase)):
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert bias is None, "Bias not supported with int8 simulation groupgemm."
if layout == "TN":
qx_data = [
......@@ -213,11 +212,14 @@ def general_grouped_gemm(
ref_scales_x = [b._rowwise_scale_inv for b in B]
ref_scales_w = [a._rowwise_scale_inv for a in A]
y, _ = w8a8_block_int8_matmul_batched(
qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
num_gemms = len(A)
seq_len = sum(m_splits) // num_gemms
out[0], _ = w8a8_block_int8_matmul_batched(
qx_data, qw_data, ref_scales_x, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
output_dtype=out_dtype
)
return y, None, None
return out, None, None
elif layout == "NN":
qdout_data = [
......@@ -229,11 +231,14 @@ def general_grouped_gemm(
ref_scales_dout = [b._rowwise_scale_inv for b in B]
ref_scales_w = [a._columnwise_scale_inv for a in A]
y, _ = w8a8_block_int8_matmul_batched(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
num_gemms = len(A)
seq_len = sum(m_splits) // num_gemms
out[0], _ = w8a8_block_int8_matmul_batched(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
output_dtype=out_dtype
)
return y, None, None
return out, None, None
elif layout == "NT":
qdout_data = [
......
......@@ -446,7 +446,7 @@ def _w8a8_block_int8_matmul_batched(
def w8a8_block_int8_matmul_batched(
A_list, B_list, As_list, Bs_list,
A_list, B_list, As_list, Bs_list, C,
block_size, output_dtype=torch.float16, best_config=None
):
A = torch.stack(A_list).contiguous() # [B, M, K]
......@@ -460,8 +460,7 @@ def w8a8_block_int8_matmul_batched(
batch, N, K = B.shape
block_n, block_k = block_size
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
assert C.size(-1) == N
config = {
"BLOCK_SIZE_M": 64,
......@@ -507,7 +506,7 @@ def w8a8_block_int8_matmul_batched(
**config,
)
return C
return C.view(-1, C.size(-1))
def apply_w8a8_block_int8_linear_batched_helper(m: int,
n: int,
......
......@@ -532,7 +532,7 @@ def w8a8_block_int8_matmul_wgrad_batched(
**config,
)
return C
return [C[i] for i in range(C.size(0))]
def apply_w8a8_block_int8_linear_batched_helper(m: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment