Commit 5eb5e316 authored by danyao12's avatar danyao12
Browse files

attn bwd pt1 add dropout function

parent 4f6f7f8e
......@@ -10,6 +10,7 @@ add_example_executable(example_batched_multihead_attention_forward_fp16 batched_
add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
......
......@@ -864,6 +864,16 @@ struct BlockwiseGemmXdlops_v2
{
}
__device__ void SetABlockStartWindow(Tuple4 a_origin = CalculateAThreadOriginDataIndex())
{
a_thread_copy_.SetSrcCoord(a_origin);
}
__device__ void SetBBlockStartWindow(Tuple4 b_origin = CalculateBThreadOriginDataIndex())
{
b_thread_copy_.SetSrcCoord(b_origin);
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment