"...resnet50_tensorflow.git" did not exist on "72775fa3cba5f09bfec185ed76a7a3f877b90e31"
Commit 5eb5e316 authored by danyao12's avatar danyao12
Browse files

attn bwd pt1: add dropout function

parent 4f6f7f8e
...@@ -10,6 +10,7 @@ add_example_executable(example_batched_multihead_attention_forward_fp16 batched_ ...
# Example targets for multi-head attention forward/backward kernels.
# (De-duplicated: the side-by-side diff rendering had repeated each line twice.)
add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
# Added by this commit: backward pass "pt1" variant.
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
......
...@@ -864,6 +864,16 @@ struct BlockwiseGemmXdlops_v2 ...
{ {
} }
// Repositions the A-operand thread copy's source window to `a_origin`.
// Defaults to this thread's canonical A origin (CalculateAThreadOriginDataIndex()),
// so a bare call restores the standard starting coordinate; callers may pass a
// custom origin to restart the block window elsewhere.
// NOTE(review): assumes SetSrcCoord fully resets the copy's source position
// (no residual per-iteration offset) — confirm against the thread-copy type.
__device__ void SetABlockStartWindow(Tuple4 a_origin = CalculateAThreadOriginDataIndex())
{
a_thread_copy_.SetSrcCoord(a_origin);
}
// Repositions the B-operand thread copy's source window to `b_origin`.
// Defaults to this thread's canonical B origin (CalculateBThreadOriginDataIndex());
// mirrors SetABlockStartWindow for the B operand.
// NOTE(review): assumes SetSrcCoord fully resets the copy's source position
// (no residual per-iteration offset) — confirm against the thread-copy type.
__device__ void SetBBlockStartWindow(Tuple4 b_origin = CalculateBThreadOriginDataIndex())
{
b_thread_copy_.SetSrcCoord(b_origin);
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment