Commit 4a013bd5 authored by yuguo's avatar yuguo
Browse files

[DCU] Fix channelwise training OOM bug

parent ddfbdaf4
......@@ -1479,8 +1479,8 @@ private:
};
// Define a static userArgs manager
static userArgsManager UAManager;
static d_userArgsManager d_UAManager;
// static userArgsManager UAManager;
// static d_userArgsManager d_UAManager;
void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB, std::vector<Tensor*>& outputD,
std::vector<int64_t>& m, std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b, hipblasOperation_t transa, hipblasOperation_t transb,
......@@ -1489,10 +1489,10 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
int device_id;
hipGetDevice(&device_id);
hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
// int device_id;
// hipGetDevice(&device_id);
// hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
// hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
......@@ -1566,20 +1566,20 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
}
// Get the default values from the groupedgemm object
groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
// groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
// Copy them to device memory
// hipblaslt_ext::UserArguments* d_userArgs;
// NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
userArgs,
m.size() * sizeof(hipblaslt_ext::UserArguments),
hipMemcpyHostToDevice, stream));
// NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
// userArgs,
// m.size() * sizeof(hipblaslt_ext::UserArguments),
// hipMemcpyHostToDevice, stream));
// Make sure to initialize everytime the algo changes
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
// NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
// NVTE_CHECK_CUDA(hipFree(userArgs));
......
......@@ -190,6 +190,7 @@ def general_gemm(
if layout == "TN":
assert out_dtype is torch.bfloat16
out_shape = B._data.shape[:-1] + (A._data.shape[0], )
x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
......@@ -212,10 +213,11 @@ def general_gemm(
use_split_accumulator,
)[0]
y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
return y, None, None, None
return y.view(out_shape), None, None, None
elif layout == "NN":
assert out_dtype is torch.bfloat16
dx_shape = B._data.shape[:-1] + (A._data.shape[-1], )
dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
......@@ -238,7 +240,7 @@ def general_gemm(
use_split_accumulator,
)[0]
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
return dx, None, None, None
return dx.view(dx_shape), None, None, None
elif layout == "NT":
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
......@@ -475,7 +477,8 @@ def general_grouped_gemm(
out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
for i in range(num_gemms):
out[0][i] = channelwise_dequantize_transB(scales_x_list[i], scales_w_list[i], y_int32[i])
return out.view(-1, out[0].size(-1)), bias, gelu_input
out[0] = out[0].view(-1, out[0].size(-1))
return out, bias, gelu_input
elif layout == "NN":
assert out_dtype is torch.bfloat16
......@@ -522,6 +525,7 @@ def general_grouped_gemm(
out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
for i in range(num_gemms):
out[0][i] = channelwise_dequantize(scales_dout_list[i], scales_w_list[i], dx_int32[i])
out[0] = out[0].view(-1, out[0].size(-1))
return out, bias, gelu_input
elif layout == "NT":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.