Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bb06d009
Commit
bb06d009
authored
Jan 18, 2023
by
ltqin
Browse files
add drop parameter in device
parent
5012068b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
15 deletions
+33
-15
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_train_xdl_cshuffle.hpp
...tched_multihead_attention_backward_train_xdl_cshuffle.hpp
+33
-15
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_train_xdl_cshuffle.hpp
View file @
bb06d009
...
@@ -76,7 +76,12 @@ __global__ void
...
@@ -76,7 +76,12 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
)
const
C0MatrixMask
c0_matrix_mask
,
const
ushort
p_dropout_in_16bits
,
const
float
p_dropout
,
const
float
rp_dropout
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -97,13 +102,8 @@ __global__ void
...
@@ -97,13 +102,8 @@ __global__ void
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
float
p_dropout
=
1
-
0.2
;
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
const
ushort
p_dropout_in_16bits
=
65536
*
p_dropout
;
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
float
rp_dropout
=
1.0
/
p_dropout
;
const
unsigned
long
long
seed
=
0
;
const
index_t
block_id
=
get_block_1d_id
();
ck
::
philox
ph
(
seed
,
0
,
block_id
*
4
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_b_grid
+
b_batch_offset
,
...
@@ -665,8 +665,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -665,8 +665,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
,
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_b1_grid_
{
p_b1_grid
},
p_b1_grid_
{
p_b1_grid
},
...
@@ -743,6 +743,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -743,6 +743,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
y_grid_desc_m_o_
);
y_grid_desc_m_o_
);
}
}
p_dropout_
=
1.f
-
p_drop
;
p_dropout_in_16bits_
=
uint16_t
(
std
::
floor
(
p_dropout_
*
65535.0
));
rp_dropout_
=
1.f
/
p_dropout_
;
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
// Print();
// Print();
}
}
...
@@ -821,6 +828,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -821,6 +828,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
index_t
batch_count_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
float
p_dropout_
;
ushort
p_dropout_in_16bits_
;
GemmAccDataType
rp_dropout_
;
unsigned
long
long
seed_
;
unsigned
long
long
offset_
;
};
};
// Invoker
// Invoker
...
@@ -895,7 +908,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -895,7 +908,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
arg
.
block_2_ctile_map_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
batch_count_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
);
arg
.
c0_matrix_mask_
,
arg
.
p_dropout_in_16bits_
,
arg
.
p_dropout_
,
arg
.
rp_dropout_
,
arg
.
seed_
,
arg
.
offset_
);
};
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
@@ -1036,7 +1054,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -1036,7 +1054,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seed
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_b
,
...
@@ -1068,7 +1086,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -1068,7 +1086,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
seed
};
seeds
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
@@ -1108,7 +1126,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -1108,7 +1126,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seed
)
// override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DataType
*>
(
p_a
),
static_cast
<
const
DataType
*>
(
p_b
),
static_cast
<
const
DataType
*>
(
p_b
),
...
@@ -1140,7 +1158,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -1140,7 +1158,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
seed
);
seeds
);
}
}
// polymorphic
// polymorphic
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment