Unverified Commit 4db6a534 authored by guangzlu's avatar guangzlu Committed by GitHub
Browse files

Merge branch 'attn-bwd-develop' into fwd-drop-verify2

parents b9cb659d 67f39ad1
......@@ -10,7 +10,7 @@ add_example_executable(example_batched_multihead_attention_forward_fp16 batched_
add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
......@@ -10,6 +10,7 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = true;
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
......
......@@ -118,7 +118,7 @@
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 0
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
// experimental feature: in-register sub-dword transpose
......
......@@ -864,6 +864,16 @@ struct BlockwiseGemmXdlops_v2
{
}
// Sets the source coordinate of a_thread_copy_ to a_origin.
// Device-only. When called without an argument, a_origin defaults to the
// value returned by CalculateAThreadOriginDataIndex().
// NOTE(review): presumably used to re-point the A-side read window at a new
// starting position between uses of this blockwise GEMM; confirm against callers.
__device__ void SetABlockStartWindow(Tuple4 a_origin = CalculateAThreadOriginDataIndex())
{
a_thread_copy_.SetSrcCoord(a_origin);
}
// Sets the source coordinate of b_thread_copy_ to b_origin.
// Device-only. When called without an argument, b_origin defaults to the
// value returned by CalculateBThreadOriginDataIndex().
// NOTE(review): presumably used to re-point the B-side read window at a new
// starting position between uses of this blockwise GEMM; confirm against callers.
__device__ void SetBBlockStartWindow(Tuple4 b_origin = CalculateBThreadOriginDataIndex())
{
b_thread_copy_.SetSrcCoord(b_origin);
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
......
......@@ -602,6 +602,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
{
is_lse_storing_ = false;
}
}
void Print() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment