"git@developer.sourcefind.cn:tsoc/spack-configs.git" did not exist on "b755db38667e529690c629a86926471c2e121455"
Commit 4a013bd5 authored by yuguo

[DCU] fix channelwise train oom bug

parent ddfbdaf4
@@ -1479,8 +1479,8 @@ private:
 };
 // Define a static userArgs manager
-static userArgsManager UAManager;
-static d_userArgsManager d_UAManager;
+// static userArgsManager UAManager;
+// static d_userArgsManager d_UAManager;
 void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB, std::vector<Tensor*>& outputD,
                           std::vector<int64_t>& m, std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b, hipblasOperation_t transa, hipblasOperation_t transb,
@@ -1489,10 +1489,10 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
   // Check compute_stream_offset valid.
   NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
-  int device_id;
-  hipGetDevice(&device_id);
-  hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
-  hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
+  // int device_id;
+  // hipGetDevice(&device_id);
+  // hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
+  // hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
   // hipblaslt_ext::UserArguments* userArgs;
   // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
@@ -1566,20 +1566,20 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
   }
   // Get the default values from the grouepdgemm object
-  groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
+  // groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
   // Copy them to device memory
   // hipblaslt_ext::UserArguments* d_userArgs;
   // NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
-  NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
-                                 userArgs,
-                                 m.size() * sizeof(hipblaslt_ext::UserArguments),
-                                 hipMemcpyHostToDevice, stream));
+  // NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
+  //                                userArgs,
+  //                                m.size() * sizeof(hipblaslt_ext::UserArguments),
+  //                                hipMemcpyHostToDevice, stream));
   // Make sure to initialize everytime the algo changes
-  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
-  NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
-  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
-  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
+  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
+  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
+  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
+  NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
   // NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
   // NVTE_CHECK_CUDA(hipFree(userArgs));
...
@@ -190,6 +190,7 @@ def general_gemm(
     if layout == "TN":
         assert out_dtype is torch.bfloat16
+        out_shape = B._data.shape[:-1] + (A._data.shape[0], )
         x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
         w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
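For context, the TN path above quantizes FP8 activations and weights to int8 with per-row scales, runs an integer GEMM, and dequantizes channelwise. A minimal sketch of that pattern, assuming symmetric per-row scales for both operands and a transposed-B layout; the helpers below are conceptual stand-ins, not the repository's per_token_quant_fp8_to_int8 / channelwise_dequantize_transB implementations:

```python
import torch

def quant_per_row_int8(x: torch.Tensor):
    """Symmetric per-row int8 quantization: each row gets one scale."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return x_int8, scale.squeeze(-1)

def int8_gemm_dequant_transB(x_int8, x_scales, w_int8, w_scales):
    """y = x @ w.T in int32, then rescale rows by x_scales and columns by w_scales."""
    y_int32 = x_int8.to(torch.int32) @ w_int8.to(torch.int32).T
    return y_int32.to(torch.float32) * x_scales[:, None] * w_scales[None, :]

x_int8, x_scales = quant_per_row_int8(torch.randn(16, 64))   # activations (tokens, hidden)
w_int8, w_scales = quant_per_row_int8(torch.randn(32, 64))   # weights (out_features, hidden)
y = int8_gemm_dequant_transB(x_int8, x_scales, w_int8, w_scales)
print(y.shape)  # torch.Size([16, 32])
```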
@@ -212,10 +213,11 @@
             use_split_accumulator,
         )[0]
         y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
-        return y, None, None, None
+        return y.view(out_shape), None, None, None
     elif layout == "NN":
         assert out_dtype is torch.bfloat16
+        dx_shape = B._data.shape[:-1] + (A._data.shape[-1], )
         dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
         w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
@@ -238,7 +240,7 @@
             use_split_accumulator,
         )[0]
         dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
-        return dx, None, None, None
+        return dx.view(dx_shape), None, None, None
     elif layout == "NT":
         assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
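The general_gemm changes above restore the caller's expected output shape: the channelwise int8 path appears to return a flat (tokens, out_features) tensor, so the result is viewed back to B's leading dimensions with A's output dimension appended. A minimal sketch with hypothetical shapes (the activation and weight sizes here are made up for illustration):

```python
import torch

# Hypothetical sizes: B holds activations (batch, seq, hidden), A holds weights (out_features, hidden).
B_data = torch.randn(4, 128, 256)   # stands in for B._data
A_data = torch.randn(512, 256)      # stands in for A._data

# Same rule as the TN branch: keep B's leading dims, append A's output dim.
out_shape = B_data.shape[:-1] + (A_data.shape[0], )

# The quantized GEMM works on flattened 2-D data and produces (tokens, out_features) ...
y_2d = B_data.reshape(-1, B_data.shape[-1]) @ A_data.T   # (4*128, 512)

# ... so the result is viewed back to (batch, seq, out_features) before returning.
y = y_2d.view(out_shape)
assert y.shape == (4, 128, 512)
```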
@@ -475,7 +477,8 @@ def general_grouped_gemm(
         out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
         for i in range(num_gemms):
             out[0][i] = channelwise_dequantize_transB(scales_x_list[i], scales_w_list[i], y_int32[i])
-        return out.view(-1, out[0].size(-1)), bias, gelu_input
+        out[0] = out[0].view(-1, out[0].size(-1))
+        return out, bias, gelu_input
     elif layout == "NN":
         assert out_dtype is torch.bfloat16
@@ -522,6 +525,7 @@
         out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
         for i in range(num_gemms):
             out[0][i] = channelwise_dequantize(scales_dout_list[i], scales_w_list[i], dx_int32[i])
+        out[0] = out[0].view(-1, out[0].size(-1))
         return out, bias, gelu_input
     elif layout == "NT":
...
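In general_grouped_gemm, out looks like a container whose first element holds the stacked per-GEMM result; the fix reshapes out[0] in place and returns the container, rather than calling .view() on the container itself as the old TN return did. A minimal sketch of the corrected flow, with hypothetical sizes, a plain list standing in for out, and the dequantization step elided:

```python
import torch

num_gemms, seq_len, out_features = 4, 128, 512

# A plain list stands in for `out`; out[0] is the flat output buffer.
out = [torch.empty(num_gemms * seq_len, out_features)]
y_deq = [torch.randn(seq_len, out_features) for _ in range(num_gemms)]  # per-GEMM dequantized results

# View the flat buffer as (num_gemms, seq_len, out_features) and fill it one GEMM at a time.
out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
for i in range(num_gemms):
    out[0][i] = y_deq[i]

# Reshape the tensor element back to 2-D, then return the container (not container.view(...)).
out[0] = out[0].view(-1, out[0].size(-1))
assert out[0].shape == (num_gemms * seq_len, out_features)
```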