Commit 6c2d60d3 authored by Xinya Zhang
Browse files

hipMemcpy -> hipMemcpyWithStream

parent 7e71583f
...@@ -651,11 +651,12 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -651,11 +651,12 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
} }
} }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.contraction_multi_d_kernel_args_.data(), arg.contraction_multi_d_kernel_args_.data(),
arg.contraction_multi_d_kernel_args_.size() * arg.contraction_multi_d_kernel_args_.size() *
sizeof(ContractionMultiDKernelArg), sizeof(ContractionMultiDKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -610,10 +610,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -610,10 +610,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
some_has_main_k_block_loop |= y; some_has_main_k_block_loop |= y;
} }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(), arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg), arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -543,10 +543,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -543,10 +543,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
} }
hipGetErrorString( hipGetErrorString(
hipMemcpy(arg.p_workspace_, hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -934,10 +934,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -934,10 +934,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
// some_has_main_k_block_loop |= y; // some_has_main_k_block_loop |= y;
// } // }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
arg.group_kernel_args_.data(), arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg), arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
...@@ -954,6 +955,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -954,6 +955,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
has_main_k_block_loop_, has_main_k_block_loop_,
Deterministic>; Deterministic>;
std::cerr << "Calling kernel kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v1 LINE: " << __LINE__ << " arg.p_workspace_ = " << arg.p_workspace_ << std::endl;
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
kernel, kernel,
......
...@@ -941,10 +941,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -941,10 +941,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
some_has_main_k_block_loop |= y; some_has_main_k_block_loop |= y;
} }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(), arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg), arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -955,10 +955,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -955,10 +955,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// some_has_main_k_block_loop |= y; // some_has_main_k_block_loop |= y;
// } // }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(), arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg), arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -962,10 +962,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -962,10 +962,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
some_has_main_k_block_loop |= y; some_has_main_k_block_loop |= y;
} }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(), arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg), arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -804,10 +804,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -804,10 +804,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
some_has_main_k_block_loop |= y; some_has_main_k_block_loop |= y;
} }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(), arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg), arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -826,10 +826,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -826,10 +826,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
some_has_main_k_block_loop |= y; some_has_main_k_block_loop |= y;
} }
hipGetErrorString(hipMemcpy(arg.p_workspace_, hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(), arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg), arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment