Commit 2183406b authored by rtmadduri's avatar rtmadduri
Browse files

LWPCK-2429: Device grouped GEMM uses Async Memcpy

Resolving merge conflicts
parent e7b62864
#pragma once #pragma once
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -603,11 +603,11 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout, ...@@ -603,11 +603,11 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
} }
hipGetErrorString( hipGetErrorString(
hipMemcpyWithStream(arg.p_workspace_, hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
stream_config.stream_id_)); stream_config.stream_id_));
auto launch_kernel = [&](auto has_main_k_block_loop, auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) { auto has_double_tail_k_block_loop) {
......
...@@ -761,11 +761,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -761,11 +761,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
float time{0.f}; float time{0.f};
hip_check_error( hip_check_error(
hipMemcpyWithStream(dev_gemm_kargs, hipMemcpyAsync(dev_gemm_kargs,
arg.gemm_kernel_args_.data(), arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
stream_config.stream_id_)); stream_config.stream_id_));
auto preprocess = [&]() { auto preprocess = [&]() {
hip_check_error(hipMemsetAsync( hip_check_error(hipMemsetAsync(
......
...@@ -940,10 +940,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop ...@@ -940,10 +940,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const void* p_host_kernel_args) const const void* p_host_kernel_args) const
{ {
arg.p_dev_gemm_args_ = p_dev_kernel_args; arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args, hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
p_host_kernel_args, p_host_kernel_args,
GetDeviceKernelArgSize(&arg), GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice));
} }
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -557,12 +557,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -557,12 +557,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
} }
} }
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_, hipGetErrorString(
arg.gemm_desc_kernel_arg_.data(), hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.size() * arg.gemm_desc_kernel_arg_.data(),
sizeof(GemmBiasTransKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
stream_config.stream_id_)); stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -421,11 +421,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -421,11 +421,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
} }
hip_check_error( hip_check_error(
hipMemcpyWithStream(arg.p_workspace_, hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_args_.data(), arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
stream_config.stream_id_)); stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -302,6 +302,13 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -302,6 +302,13 @@ bool profile_grouped_gemm_impl(int do_verification,
rtol, rtol,
atol); atol);
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i],
"Error: Incorrect results!",
rtol,
atol);
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",") LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment