Unverified Commit 9488f1c9 authored by rtmadduri's avatar rtmadduri Committed by GitHub
Browse files

LWPCK-2429: Device grouped GEMM uses Async Memcpy (#1695)



* LWPCK-2429: Device grouped GEMM uses Async Memcpy
Resolving merge conflicts

* reverting changes to profile_grouped_gemm

* revert date change

---------
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
parent 44828b7c
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -603,11 +603,11 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
hipGetErrorString(
hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
......
......@@ -761,11 +761,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
float time{0.f};
hip_check_error(
hipMemcpyWithStream(dev_gemm_kargs,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipMemcpyAsync(dev_gemm_kargs,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
auto preprocess = [&]() {
hip_check_error(hipMemsetAsync(
......
......@@ -940,10 +940,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const void* p_host_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args,
p_host_kernel_args,
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
p_host_kernel_args,
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
}
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
......
......@@ -557,12 +557,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
}
}
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() *
sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipGetErrorString(
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
......@@ -421,11 +421,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
hip_check_error(
hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment