gaoqiong / composable_kernel
Commit 9241ff62, authored Jan 16, 2023 by guangzlu
Parent 11eed39f

modified method to set offset in philox
Showing 8 changed files with 51 additions and 39 deletions.
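Philox is a counter-based RNG: every draw is a pure function of a key (the seed) and a counter (a subsequence plus an offset), so identical inputs always reproduce identical values and threads never have to share generator state. This commit changes how that counter is assembled, not the generator itself. A toy sketch of the counter-based idea, assuming nothing about ck::philox internals (the mixer below is a stand-in for the real Philox rounds; all names here are hypothetical):

// Toy counter-based RNG: NOT ck::philox, just the shape of the idea.
inline unsigned long long mix(unsigned long long x)
{
    // 64-bit finalizer (splitmix64-style): a stateless avalanche function.
    x ^= x >> 33;
    x *= 0xff51afd7ed558ccdULL;
    x ^= x >> 33;
    x *= 0xc4ceb9fe1a85ec53ULL;
    x ^= x >> 33;
    return x;
}

// The i-th draw of stream (seed, subsequence) at base offset `offset`
// depends only on its inputs, never on earlier draws.
inline unsigned long long toy_counter_rng(unsigned long long seed,
                                          unsigned long long subsequence,
                                          unsigned long long offset,
                                          unsigned long long i)
{
    return mix(seed ^ mix(subsequence) ^ mix(offset + i));
}

Because the mapping is stateless, replaying a run with the same {seed, offset} pair regenerates the same dropout mask, which is what the new caller-supplied offset below enables.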
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_train.inc   +2 -1
example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute_train.inc   +2 -1
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp   +4 -9
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp   +2 -1
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp   +1 -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp   +20 -12
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp   +18 -12
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp   +2 -2
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_train.inc  (file mode 100755 → 100644)

@@ -178,7 +178,8 @@ int run(int argc, char* argv[])
                               acc0_element_op,
                               b1_element_op,
                               c_element_op,
-                              0);        // dropout ratio
+                              0,         // dropout ratio
+                              {0, 64});  // dropout random seed and offset

     if(!gemm.IsSupportedArgument(argument))
     {
example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute_train.inc

@@ -218,7 +218,8 @@ int run(int argc, char* argv[])
                               acc0_element_op,
                               b1_element_op,
                               c_element_op,
-                              0);         // dropout ratio
+                              0,          // dropout ratio
+                              {0, 448});  // dropout random seed and offset

     // specify workspace for problem_desc
     DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
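In both examples the trailing argument changes from a bare dropout ratio to the ratio plus a braced {seed, offset} pair. The brace-init list binds to the new std::tuple parameter introduced in the device headers below; a self-contained illustration of that binding:

#include <iostream>
#include <tuple>

// The braced {0, 64} binds to the new
// std::tuple<unsigned long long, unsigned long long> parameter.
void take_seeds(std::tuple<unsigned long long, unsigned long long> seeds)
{
    std::cout << "seed=" << std::get<0>(seeds)
              << " offset=" << std::get<1>(seeds) << std::endl;
}

int main()
{
    take_seeds({0, 64}); // prints: seed=0 offset=64
}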
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp

@@ -17,10 +17,7 @@ struct BlockwiseDropout
     static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);

     template <typename CThreadBuffer>
-    __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf,
-                                          ck::philox ph,
-                                          const int repeat_index,
-                                          const int total_repeats)
+    __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph)
     {
         auto execute_dropout = [&](bool keep, DataType val) {

@@ -28,15 +25,13 @@ struct BlockwiseDropout
         };

         constexpr int tmp_size = MRepeat * KRepeat;

         int philox_calls = tmp_size / 8;
-        int tid          = get_thread_global_1d_id();
-        unsigned long long uni_subsequence = tid * total_repeats * philox_calls + repeat_index * philox_calls;

         ushort tmp[tmp_size];
         for(int i = 0; i < philox_calls; i++)
         {
-            ph.get_random_8x16((tmp + i * 8), (uni_subsequence + i));
+            ph.get_random_8x16((tmp + i * 8));
         }

         block_sync_lds();
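The loop above only fills tmp with raw 16-bit draws; the keep/drop decision and rescaling happen in the elided part of ApplyDropout (via execute_dropout). A hedged sketch of the usual inverted-dropout comparison, assuming the threshold convention of p_dropout_in_16bits_ set in the device files further down; treat it as an assumption about the shape of that logic, not the real body:

#include <cmath>
#include <cstdint>

// keep_prob in (0, 1]; mirrors p_dropout_in_16bits_ / p_dropout_rescale_.
inline float dropout_one(float val, std::uint16_t rnd, float keep_prob)
{
    const std::uint16_t threshold =
        std::uint16_t(std::floor(keep_prob * 65535.0));
    const bool keep = rnd <= threshold; // ~keep_prob of all 16-bit values pass
    return keep ? val * (1.f / keep_prob) : 0.f; // rescale survivors
}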
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp

@@ -5,6 +5,7 @@
 #include <iostream>
 #include <vector>
+#include <tuple>

 #include "device_base.hpp"
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"

@@ -117,7 +118,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
             B1ElementwiseOperation b1_element_op,
             CElementwiseOperation c_element_op,
             float p_dropout,
-            const unsigned long long seed = 0) = 0;
+            std::tuple<unsigned long long, unsigned long long> seeds) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
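Replacing the defaulted `const unsigned long long seed = 0` with a mandatory tuple forces every call site to state both the seed and the offset. A self-contained toy of the same signature change (toy names, not the CK classes):

#include <tuple>

// Toy mirror of the interface change: the defaulted seed parameter becomes
// a mandatory {seed, offset} tuple.
struct BaseOpToy
{
    virtual void MakeArgument(float p_dropout,
                              std::tuple<unsigned long long, unsigned long long> seeds) = 0;
    virtual ~BaseOpToy() = default;
};

struct OpToy final : BaseOpToy
{
    void MakeArgument(float,
                      std::tuple<unsigned long long, unsigned long long> seeds) override
    {
        (void)seeds; // the real code unpacks with std::get<0> / std::get<1>
    }
};

int main()
{
    OpToy op;
    op.MakeArgument(0.f, {0, 64}); // both values must now be spelled out
}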
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp

@@ -129,7 +129,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
             B1ElementwiseOperation b1_element_op,
             CElementwiseOperation c_element_op,
             float p_dropout,
-            const unsigned long long seed = 0) = 0;
+            std::tuple<unsigned long long, unsigned long long> seeds) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp

@@ -69,8 +69,9 @@ __global__ void
             const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
             const C0MatrixMask c0_matrix_mask,
             const ushort p_dropout_in_16bits,
-            GemmAccDataType p_dropout_rescale,
-            const unsigned long long seed)
+            const GemmAccDataType p_dropout_rescale,
+            const unsigned long long seed,
+            const unsigned long long offset)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -89,8 +90,8 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));

-    const index_t block_id = get_block_1d_id();
-    ck::philox ph(seed, 0, block_id * 4);
+    const index_t global_thread_id = get_thread_global_1d_id();
+    ck::philox ph(seed, global_thread_id, offset);

     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
         p_a_grid + a_batch_offset,
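The old kernel gave every thread of a block the same starting counter (subsequence 0, offset block_id * 4) and relied on the per-call subsequence arithmetic that this commit removes from BlockwiseDropout; the new kernel bakes the global thread id into the subsequence once, at construction, and takes the offset from the caller. A toy host-side contrast of the two counter layouts (illustrative values only):

#include <cstdio>

int main()
{
    const unsigned long long seed = 0, user_offset = 64; // as in the example: {0, 64}
    const int threads_per_block = 2, blocks = 2;
    for(int b = 0; b < blocks; ++b)
        for(int t = 0; t < threads_per_block; ++t)
        {
            const unsigned long long global_tid =
                (unsigned long long)(b * threads_per_block + t);
            // old: ck::philox ph(seed, 0, block_id * 4) -- the same for every
            // thread of a block; per-thread uniqueness came from the per-call
            // subsequence computed inside ApplyDropout (removed above).
            std::printf("old: block %d thread %d -> (seed=%llu, sub=0, off=%d)\n",
                        b, t, seed, b * 4);
            // new: ck::philox ph(seed, global_thread_id, offset) -- the thread
            // id is baked into the subsequence once, at construction.
            std::printf("new: global thread %llu -> (seed=%llu, sub=%llu, off=%llu)\n",
                        global_tid, seed, global_tid, user_offset);
        }
}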
@@ -478,7 +479,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                  B1ElementwiseOperation b1_element_op,
                  CElementwiseOperation c_element_op,
                  float p_dropout,
-                 unsigned long long seed)
+                 std::tuple<unsigned long long, unsigned long long> seeds)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_b1_grid_{p_b1_grid},

@@ -527,8 +528,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                              b_grid_desc_g_n_k_,
                              b1_grid_desc_g_n_k_,
                              c_grid_desc_g_m_n_,
-                             type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
-              seed_(seed)
+                             type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
         {
             // TODO ANT: implement bias addition
             ignore = p_acc0_biases;

@@ -554,6 +554,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
             p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
             p_dropout_           = 1.f / p_dropout_;
             p_dropout_rescale_   = type_convert<GemmAccDataType>(p_dropout_);
+            seed_   = std::get<0>(seeds);
+            offset_ = std::get<1>(seeds);
+            std::cout << "seed_" << seed_ << std::endl;
+            std::cout << "offset_" << offset_ << std::endl;
         }

         void Print() const
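The constructor derives two scalars from the dropout probability: a 16-bit comparison threshold and a rescale factor. A worked numeric check of those two assignments; that p_dropout_ already holds the keep probability at this point (for example after an elided p_dropout_ = 1 - ratio) is an assumption, but the arithmetic itself is as shown:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float p_dropout = 0.9f; // assumed keep probability at this point
    const std::uint16_t threshold =
        std::uint16_t(std::floor(p_dropout * 65535.0)); // p_dropout_in_16bits_
    p_dropout = 1.f / p_dropout;                        // reused as rescale factor
    std::printf("threshold = %u, rescale = %f\n", unsigned(threshold), p_dropout);
    // prints: threshold = 58981, rescale = 1.111111
}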
@@ -619,6 +625,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         ushort p_dropout_in_16bits_;
         GemmAccDataType p_dropout_rescale_;
         unsigned long long seed_;
+        unsigned long long offset_;
         bool is_dropout_;
     };
@@ -692,7 +699,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                               arg.c0_matrix_mask_,
                               arg.p_dropout_in_16bits_,
                               arg.p_dropout_rescale_,
-                              arg.seed_);
+                              arg.seed_,
+                              arg.offset_);
         };

         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
@@ -846,7 +854,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                      B1ElementwiseOperation b1_element_op,
                      CElementwiseOperation c_element_op,
                      float p_dropout,
-                     const unsigned long long seed = 0)
+                     std::tuple<unsigned long long, unsigned long long> seeds)
     {
         return Argument{p_a,
                         p_b,

@@ -874,7 +882,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                         b1_element_op,
                         c_element_op,
                         p_dropout,
-                        seed};
+                        seeds};
     }

     static auto MakeInvoker() { return Invoker{}; }

@@ -910,7 +918,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         B1ElementwiseOperation b1_element_op,
         CElementwiseOperation c_element_op,
         float p_dropout,
-        const unsigned long long seed = 0) override
+        std::tuple<unsigned long long, unsigned long long> seeds) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),

@@ -938,7 +946,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                                           b1_element_op,
                                           c_element_op,
                                           p_dropout,
-                                          seed);
+                                          seeds);
     }

     // polymorphic
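MakeArgumentPointer and MakeInvokerPointer are the polymorphic entry points seen above: arguments and invokers travel as unique_ptrs to base classes so callers can hold any concrete device op behind one interface. A self-contained toy of that factory pattern (all names are stand-ins, not the CK API):

#include <cstdio>
#include <memory>

struct BaseArgToy { virtual ~BaseArgToy() = default; };
struct BaseInvokerToy
{
    virtual void Run(const BaseArgToy*) = 0;
    virtual ~BaseInvokerToy() = default;
};

struct DeviceOpToy
{
    struct Arg : BaseArgToy { unsigned long long seed, offset; };
    struct Invoker : BaseInvokerToy
    {
        void Run(const BaseArgToy* p) override
        {
            auto* a = static_cast<const Arg*>(p);
            std::printf("run with seed=%llu offset=%llu\n", a->seed, a->offset);
        }
    };
    std::unique_ptr<BaseArgToy> MakeArgumentPointer(unsigned long long s, unsigned long long o)
    {
        auto a = std::make_unique<Arg>();
        a->seed = s;
        a->offset = o;
        return a;
    }
    std::unique_ptr<BaseInvokerToy> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
};

int main()
{
    DeviceOpToy op;
    auto arg = op.MakeArgumentPointer(0, 64);
    op.MakeInvokerPointer()->Run(arg.get()); // prints: run with seed=0 offset=64
}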
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp

@@ -46,15 +46,17 @@ __global__ void
             const B1ElementwiseOperation b1_element_op,
             const CElementwiseOperation c_element_op,
             const ushort p_dropout_in_16bits,
-            GemmAccDataType p_dropout_rescale,
-            const unsigned long long seed)
+            const GemmAccDataType p_dropout_rescale,
+            const unsigned long long seed,
+            const unsigned long long offset)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
-    ck::philox ph(seed, 0, block_id * 4);
+    const index_t global_thread_id = get_thread_global_1d_id();
+    ck::philox ph(seed, global_thread_id, offset);

     const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
         cast_pointer_to_generic_address_space(group_kernel_args));

@@ -519,13 +521,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                  B1ElementwiseOperation b1_element_op,
                  CElementwiseOperation c_element_op,
                  float p_dropout,
-                 unsigned long long seed)
+                 std::tuple<unsigned long long, unsigned long long> seeds)
             : a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               acc_element_op_{acc_element_op},
               b1_element_op_{b1_element_op},
-              c_element_op_{c_element_op},
-              seed_(seed)
+              c_element_op_{c_element_op}
         {
             // TODO ANT: implement bias addition
             group_count_ = problem_desc_vec.size();

@@ -647,6 +648,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
             p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
             p_dropout_           = 1.f / p_dropout_;
             p_dropout_rescale_   = type_convert<GemmAccDataType>(p_dropout_);
+            seed_   = std::get<0>(seeds);
+            offset_ = std::get<1>(seeds);
         }

         std::vector<GroupKernelArg> group_kernel_args_;

@@ -664,6 +668,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         float p_dropout_;
         ushort p_dropout_in_16bits_;
         unsigned long long seed_;
+        unsigned long long offset_;
         GemmAccDataType p_dropout_rescale_;
         bool is_dropout_;
     };

@@ -726,7 +731,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                               arg.c_element_op_,
                               arg.p_dropout_in_16bits_,
                               arg.p_dropout_rescale_,
-                              arg.seed_);
+                              arg.seed_,
+                              arg.offset_);
         };

         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need

@@ -895,7 +901,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                      B1ElementwiseOperation b1_element_op,
                      CElementwiseOperation c_element_op,
                      float p_dropout,
-                     const unsigned long long seed = 0)
+                     std::tuple<unsigned long long, unsigned long long> seeds)
     {
         return Argument{p_a_vec,
                         p_b_vec,

@@ -911,7 +917,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                         b1_element_op,
                         c_element_op,
                         p_dropout,
-                        seed};
+                        seeds};
     }

     static auto MakeInvoker() { return Invoker{}; }

@@ -932,7 +938,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         B1ElementwiseOperation b1_element_op,
         CElementwiseOperation c_element_op,
         float p_dropout,
-        const unsigned long long seed = 0) override
+        std::tuple<unsigned long long, unsigned long long> seeds) override
     {
         return std::make_unique<Argument>(p_a_vec,
                                           p_b_vec,

@@ -948,7 +954,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                                           b1_element_op,
                                           c_element_op,
                                           p_dropout,
-                                          seed);
+                                          seeds);
     }

     // polymorphic
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp

@@ -781,6 +781,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
         do
         {
             auto n_block_data_idx_on_grid =

@@ -875,8 +876,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
             if constexpr(IsDropout) // dropout
             {
-                blockwise_dropout.ApplyDropout(
-                    acc_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop);
+                blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
             }

             // TODO: may convert to log domain
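The dropout call sits behind if constexpr(IsDropout), so instantiations without dropout compile the branch away entirely. A self-contained toy of that compile-time toggle:

#include <cstdio>

// IsDropout = false discards the branch at compile time, so non-dropout
// instantiations of the gridwise kernel pay nothing for it.
template <bool IsDropout>
void process_tile(float& acc)
{
    if constexpr(IsDropout)
    {
        acc *= 0.5f; // stand-in for blockwise_dropout.ApplyDropout(acc_thread_buf, ph)
    }
}

int main()
{
    float a = 2.f, b = 2.f;
    process_tile<true>(a);  // a -> 1.0
    process_tile<false>(b); // b unchanged
    std::printf("%f %f\n", a, b);
}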