"conda/vscode:/vscode.git/clone" did not exist on "908779b26d3daa03fb258443e89e96f4b81f90d7"
Commit 6c2d60d3 authored by Xinya Zhang's avatar Xinya Zhang
Browse files

hipMemcpy -> hipMemcpyWithStream

parent 7e71583f
......@@ -651,11 +651,12 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
}
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.contraction_multi_d_kernel_args_.data(),
arg.contraction_multi_d_kernel_args_.size() *
sizeof(ContractionMultiDKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
......@@ -610,10 +610,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
some_has_main_k_block_loop |= y;
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
......@@ -543,10 +543,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
}
hipGetErrorString(
hipMemcpy(arg.p_workspace_,
hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
......@@ -934,10 +934,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
// some_has_main_k_block_loop |= y;
// }
hipGetErrorString(hipMemcpy(arg.p_workspace_,
hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......@@ -954,6 +955,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
has_main_k_block_loop_,
Deterministic>;
std::cerr << "Calling kernel kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v1 LINE: " << __LINE__ << " arg.p_workspace_ = " << arg.p_workspace_ << std::endl;
return launch_and_time_kernel(
stream_config,
kernel,
......
......@@ -941,10 +941,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
some_has_main_k_block_loop |= y;
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
......@@ -955,10 +955,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// some_has_main_k_block_loop |= y;
// }
hipGetErrorString(hipMemcpy(arg.p_workspace_,
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
......@@ -962,10 +962,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
some_has_main_k_block_loop |= y;
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
......@@ -804,10 +804,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
some_has_main_k_block_loop |= y;
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
......@@ -826,10 +826,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
some_has_main_k_block_loop |= y;
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment