gaoqiong / composable_kernel

Commit 7a8352bc
authored Sep 05, 2023 by danyao12

fix split kernels dropout related bugs

parent edbb3439
Showing 5 changed files with 37 additions and 15 deletions (+37 -15)
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp  +7 -3
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp  +7 -3
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp  +8 -3
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp  +8 -3
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp  +7 -3
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp

@@ -46,7 +46,8 @@ __global__ void
             const DGridDescriptor_M d_grid_desc_m,
             const Block2CTileMap block_2_ctile_map,
             const index_t batch_count,
-            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+            const float p_drop)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))

@@ -66,7 +67,8 @@ __global__ void
             p_d_grid + d_batch_offset,
             y_grid_desc_mblock_mperblock_nblock_nperblock,
             d_grid_desc_m,
-            block_2_ctile_map);
+            block_2_ctile_map,
+            p_drop);
 #else
     ignore = p_y_grid;

@@ -77,6 +79,7 @@ __global__ void
     ignore = block_2_ctile_map;
     ignore = batch_count;
     ignore = compute_base_ptr_of_batch;
+    ignore = p_drop;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

@@ -1131,7 +1134,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 arg.d_grid_desc_m_,
                 arg.d_block_2_ctile_map_,
                 arg.batch_count_,
-                arg.compute_base_ptr_of_batch_);
+                arg.compute_base_ptr_of_batch_,
+                arg.p_drop_);
         };

         ave_time = launch_kernel();
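The change above threads the dropout probability from the device-level argument (arg.p_drop_) into the __global__ ydotygrad kernel as a trailing const float p_drop parameter, forwards it into the gridwise Run call, and assigns it to ignore on architectures where the kernel body is compiled out (the usual way CK silences unused-parameter warnings). The remaining device-side files below repeat the same pattern. As a minimal HIP sketch of that pattern, with made-up names (ydotygrad_like_kernel, plain pointers instead of CK grid descriptors and block-to-tile maps) and not the actual CK kernel:

// Sketch only: shows threading a scalar dropout probability through a guarded
// __global__ kernel, under assumed simplified types. Names are hypothetical.
#include <hip/hip_runtime.h>

__global__ void ydotygrad_like_kernel(const float* p_y_grid,
                                      const float* p_ygrad_grid,
                                      float* p_d_grid,
                                      int length,
                                      float p_drop) // new trailing parameter
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    // Same 1 - p_drop factor the gridwise kernel derives from p_drop.
    const float scale = 1.0f - p_drop;
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < length)
    {
        // Simplified elementwise stand-in for the real per-row Y . dY reduction.
        p_d_grid[i] = p_y_grid[i] * p_ygrad_grid[i] * scale;
    }
#else
    // Mirrors the "ignore = ..." lines in the diff: consume every parameter on
    // targets where the kernel body is not emitted.
    (void)p_y_grid;
    (void)p_ygrad_grid;
    (void)p_d_grid;
    (void)length;
    (void)p_drop;
#endif
}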
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp

@@ -46,7 +46,8 @@ __global__ void
             const DGridDescriptor_M d_grid_desc_m,
             const Block2CTileMap block_2_ctile_map,
             const index_t batch_count,
-            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+            const float p_drop)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))

@@ -66,7 +67,8 @@ __global__ void
             p_d_grid + d_batch_offset,
             y_grid_desc_mblock_mperblock_nblock_nperblock,
             d_grid_desc_m,
-            block_2_ctile_map);
+            block_2_ctile_map,
+            p_drop);
 #else
     ignore = p_y_grid;

@@ -77,6 +79,7 @@ __global__ void
     ignore = block_2_ctile_map;
     ignore = batch_count;
     ignore = compute_base_ptr_of_batch;
+    ignore = p_drop;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

@@ -1143,7 +1146,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                 arg.d_grid_desc_m_,
                 arg.d_block_2_ctile_map_,
                 arg.batch_count_,
-                arg.compute_base_ptr_of_batch_);
+                arg.compute_base_ptr_of_batch_,
+                arg.p_drop_);
         };

         ave_time = launch_kernel();
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp

@@ -33,7 +33,9 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
         kernel_grouped_multihead_attention_backward_ydotygrad_v1(
-            const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
+            const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
+            const index_t group_count,
+            const float p_dropout)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))

@@ -74,10 +76,12 @@ __global__ void
             arg_ptr[group_id].p_d_grid_ + d_batch_offset,
             arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
             arg_ptr[group_id].d_grid_desc_m_,
-            arg_ptr[group_id].d_block_2_ctile_map_);
+            arg_ptr[group_id].d_block_2_ctile_map_,
+            p_dropout);
 #else
     ignore = group_kernel_args;
     ignore = group_count;
+    ignore = p_dropout;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

@@ -1175,7 +1179,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 dim3(BlockSize),
                 0,
                 cast_pointer_to_constant_address_space(arg.p_workspace_),
-                arg.group_count_);
+                arg.group_count_,
+                arg.p_dropout_);
         };
         ave_time = launch_kernel();
     }
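The grouped variant above (and its v2 counterpart below) differs only in how arguments reach the kernel: per-group argument blocks sit behind a CK_CONSTANT_ADDRESS_SPACE pointer, so the new dropout probability is passed as a separate scalar next to group_count instead of being copied into every group's argument block. The sketch below is hypothetical (GroupKernelArg, grouped_ydotygrad_like_kernel, and a simple blockIdx.y group mapping are all invented for illustration); the real kernels map blocks to groups through CK's tile maps and descriptors.

// Sketch only: one shared dropout scale for many per-group argument blocks.
#include <hip/hip_runtime.h>

struct GroupKernelArg // stand-in for the per-group argument block
{
    float* p_d_grid_; // this group's D buffer
    int    length_;   // number of D elements in this group
};

__global__ void grouped_ydotygrad_like_kernel(const GroupKernelArg* group_kernel_args,
                                              int group_count,
                                              float p_dropout) // new trailing scalar
{
    const int group_id = blockIdx.y; // illustrative block-to-group mapping
    if(group_id >= group_count)
        return;

    const GroupKernelArg arg   = group_kernel_args[group_id];
    const float          scale = 1.0f - p_dropout;

    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < arg.length_)
    {
        arg.p_d_grid_[i] *= scale; // same scale applied in every group
    }
}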
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp

@@ -32,7 +32,9 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
         kernel_grouped_multihead_attention_backward_ydotygrad_v2(
-            const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
+            const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
+            const index_t group_count,
+            const float p_dropout)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))

@@ -73,10 +75,12 @@ __global__ void
             arg_ptr[group_id].p_d_grid_ + d_batch_offset,
             arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
             arg_ptr[group_id].d_grid_desc_m_,
-            arg_ptr[group_id].d_block_2_ctile_map_);
+            arg_ptr[group_id].d_block_2_ctile_map_,
+            p_dropout);
 #else
     ignore = group_kernel_args;
     ignore = group_count;
+    ignore = p_dropout;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

@@ -1244,7 +1248,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                 dim3(BlockSize),
                 0,
                 cast_pointer_to_constant_address_space(arg.p_workspace_),
-                arg.group_count_);
+                arg.group_count_,
+                arg.p_dropout_);
         };
         ave_time = launch_kernel();
     }
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp

@@ -142,8 +142,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
             y_grid_desc_mblock_mperblock_nblock_nperblock,
         const DGridDesc_M& d_grid_desc_m,
-        const Block2CTileMap& block_2_ctile_map)
+        const Block2CTileMap& block_2_ctile_map,
+        const float p_drop)
     {
+        const FloatD p_dropout = type_convert<FloatD>(1.0f - p_drop);
+        const tensor_operation::element_wise::Scale scale_p_dropout(p_dropout);
+
         const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
         const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(

@@ -247,7 +251,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
             FloatD,
             decltype(d_thread_desc_mblock_m1),
             decltype(d_grid_desc_mblock_mperblock),
-            ck::tensor_operation::element_wise::PassThrough,
+            ck::tensor_operation::element_wise::Scale,
             Sequence<1, 1>,
             Sequence<0, 1>,
             1,

@@ -258,7 +262,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
             d_grid_desc_mblock_mperblock,
             make_multi_index(block_work_idx_m,          // mblock
                             get_thread_local_1d_id()), // mperblock
-            ck::tensor_operation::element_wise::PassThrough{}};
+            scale_p_dropout};

         // copy from VGPR to Global
         d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_m1,
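In the gridwise kernel the fix has two parts: the new p_drop parameter is converted once into a scale factor p_dropout = 1 - p_drop wrapped in tensor_operation::element_wise::Scale, and the VGPR-to-global copy of D swaps its element-wise operator from PassThrough to that Scale, so every D value is multiplied by 1 - p_drop as it is written out. The snippet below uses hypothetical stand-in functors (PassThroughOp, ScaleOp, not the CK types) just to show the numeric effect of that swap; it is ordinary host C++, not CK device code.

// Sketch only: stand-ins for the element-wise ops swapped in the diff.
#include <cstdio>

struct PassThroughOp // old behaviour: y = x
{
    void operator()(float& y, const float& x) const { y = x; }
};

struct ScaleOp // new behaviour: y = scale * x
{
    float scale_;
    explicit ScaleOp(float scale) : scale_(scale) {}
    void operator()(float& y, const float& x) const { y = scale_ * x; }
};

int main()
{
    const float p_drop  = 0.1f; // dropout probability passed to the kernel
    const float d_value = 2.0f; // a D = sum(Y * dY) partial result

    float out_old = 0.0f;
    float out_new = 0.0f;

    PassThroughOp{}(out_old, d_value);        // before the fix: 2.0 written as-is
    ScaleOp{1.0f - p_drop}(out_new, d_value); // after the fix:  2.0 * 0.9 = 1.8

    std::printf("pass-through: %f, scaled by (1 - p_drop): %f\n", out_old, out_new);
    return 0;
}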