You need to sign in or sign up before continuing.
Unverified Commit e2ebc8e7 authored by who who who's avatar who who who Committed by GitHub
Browse files

replace hipMemcpy with hipMemcpyWithStream (#734)

parent 9eae73df
...@@ -611,10 +611,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -611,10 +611,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
some_has_main_k_block_loop |= y; some_has_main_k_block_loop |= y;
} }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(), arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg), arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -652,11 +652,12 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -652,11 +652,12 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
} }
} }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.contraction_multi_d_kernel_args_.data(), arg.contraction_multi_d_kernel_args_.data(),
arg.contraction_multi_d_kernel_args_.size() * arg.contraction_multi_d_kernel_args_.size() *
sizeof(ContractionMultiDKernelArg), sizeof(ContractionMultiDKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -597,10 +597,11 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout, ...@@ -597,10 +597,11 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
} }
} }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
auto launch_kernel = [&](auto has_main_k_block_loop, auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) { auto has_double_tail_k_block_loop) {
......
...@@ -549,10 +549,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -549,10 +549,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
} }
hipGetErrorString( hipGetErrorString(
hipMemcpy(arg.p_workspace_, hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -406,10 +406,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -406,10 +406,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
} }
} }
hip_check_error(hipMemcpy(arg.p_workspace_, hip_check_error(hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_kernel_args_.data(), arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment