Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
11eed39f
Commit
11eed39f
authored
Jan 15, 2023
by
guangzlu
Browse files
added dropou scale into device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
parent
2ac0eefd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
0 deletions
+8
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
..._batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
+8
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
View file @
11eed39f
...
...
@@ -27,6 +27,7 @@ template <typename GridwiseGemm,
typename
FloatAB
,
typename
FloatC
,
typename
FloatLSE
,
typename
GemmAccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
...
...
@@ -68,6 +69,7 @@ __global__ void
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
const
ushort
p_dropout_in_16bits
,
GemmAccDataType
p_dropout_rescale
,
const
unsigned
long
long
seed
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -110,6 +112,7 @@ __global__ void
block_2_ctile_map
,
c0_matrix_mask
,
p_dropout_in_16bits
,
p_dropout_rescale
,
ph
);
#else
ignore
=
p_a_grid
;
...
...
@@ -549,6 +552,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
is_dropout_
=
p_dropout
>
0.0
;
//
p_dropout_
=
1.
f
-
p_dropout
;
p_dropout_in_16bits_
=
uint16_t
(
std
::
floor
(
p_dropout_
*
65535.0
));
p_dropout_
=
1.
f
/
p_dropout_
;
p_dropout_rescale_
=
type_convert
<
GemmAccDataType
>
(
p_dropout_
);
}
void
Print
()
const
...
...
@@ -612,6 +617,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float
p_dropout_
;
ushort
p_dropout_in_16bits_
;
GemmAccDataType
p_dropout_rescale_
;
unsigned
long
long
seed_
;
bool
is_dropout_
;
};
...
...
@@ -643,6 +649,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
LSEDataType
,
GemmAccDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
...
...
@@ -684,6 +691,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
,
arg
.
p_dropout_in_16bits_
,
arg
.
p_dropout_rescale_
,
arg
.
seed_
);
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment