gaoqiong / composable_kernel · Commits · 4fec5ad3

Commit 4fec5ad3, authored Oct 28, 2022 by aska-0096

    Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into wmma_op

Parents: 24faa1fc, 87fd1152
Changes: 282

Showing 20 changed files with 1925 additions and 1038 deletions (+1925 / -1038)
Changed files shown (20 of 282):

  example/34_batchnorm/batchnorm_forward_nhwc.cpp                                                        +222  -95
  example/34_batchnorm/batchnorm_infer_impl.hpp                                                           +27   -15
  example/34_batchnorm/batchnorm_infer_nhwc.cpp                                                           +35   -21
  include/ck/ck.hpp                                                                                       +5    -0
  include/ck/tensor_description/tensor_space_filling_curve.hpp                                            +6    -4
  include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                                         +36   -0
  include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp                             +2    -1
  include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp                     +39   -28
  include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp                                     +13   -8
  include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp                                       +3    -1
  include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp                     +25   -19
  include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp        +235  -313
  include/ck/tensor_operation/gpu/device/device_reduce.hpp                                                +19   -13
  include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp      +6    -3
  include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp   +316  -370
  include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp           +7    -24
  include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp                           +711  -0
  include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp                                +43   -37
  include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp                                +36   -30
  include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp                                     +139  -56
example/34_batchnorm/batchnorm_forward_nhwc.cpp  (+222 -95)

@@ -15,13 +15,9 @@
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/host_common_util.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
-#include "batchnorm_forward_impl.hpp"
-#include "ck/library/utility/host_common_util.hpp"
-
-template <typename InOutDataType, typename AccDataType>
-using ReferenceBatchNormFwdInstance = ck::tensor_operation::host::
-    ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType, AccDataType>;
+#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
 static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
                                        {"verify", required_argument, nullptr, 'v'},

@@ -41,9 +37,10 @@ class BatchNormFwdArg
     bool updateMovingAverage;
     bool saveMeanAndInvVariance;
 
     int data_type    = 0;
     int init_method  = 2;
     bool time_kernel = false;
+    bool use_multiblock_welford = false;
 
     public:
     void show_usage(const char* cmd)

@@ -68,6 +65,7 @@ class BatchNormFwdArg
                      "value, 3=decimal value)"
                   << std::endl;
         std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
+        std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl;
     };
 
     int processArgs(int argc, char* argv[])

@@ -110,14 +108,15 @@ class BatchNormFwdArg
             };
         };
 
-        if(optind + 5 > argc)
+        if(optind + 6 > argc)
            throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
 
        data_type              = std::atoi(argv[optind++]);
        updateMovingAverage    = std::atoi(argv[optind++]);
        saveMeanAndInvVariance = std::atoi(argv[optind++]);
        init_method            = std::atoi(argv[optind++]);
-       time_kernel            = static_cast<bool>(std::atoi(argv[optind]));
+       time_kernel            = static_cast<bool>(std::atoi(argv[optind++]));
+       use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));
 
        if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
            return (-1);

@@ -128,7 +127,7 @@ class BatchNormFwdArg
 
 using namespace ck;
 
-template <typename InOutDataType, typename AccDataType>
+template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
 bool bnorm_fwd_nhwc_test(bool do_verification,
                          int init_method,
                          bool time_kernel,

@@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
               scaleBiasMeanVarStrides.end(),
               i_scaleBiasMeanVarStrides.begin());
 
-    int result = 0;
-
-    // used for saving meansquare
-    DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() +
-                        128);
-
-    void* p_tmp_mean = workspace.GetDeviceBuffer();
-    void* p_tmp_meansquare =
-        static_cast<char*>(p_tmp_mean) +
-        (sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64;
-
-    result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
-        time_kernel,
-        updateMovingAverage,
-        saveMeanAndInvVariance,
-        {0, 1, 2},
-        i_inOutLengths,
-        i_inOutStrides,
-        i_inOutStrides,
-        i_scaleBiasMeanVarLengths,
-        i_scaleBiasMeanVarStrides,
-        x_dev.GetDeviceBuffer(),
-        bnScale_dev.GetDeviceBuffer(),
-        bnBias_dev.GetDeviceBuffer(),
-        y_dev.GetDeviceBuffer(),
-        averageFactor,
-        updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
-        updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr,
-        epsilon,
-        saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
-        saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
-        p_tmp_mean,
-        p_tmp_meansquare);
-
-    if(result < 0)
-        return (false);
+    using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
+
+    using DeviceBatchNormFwdInstance =
+        ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
+                                                             InOutDataType,
+                                                             AccDataType,
+                                                             AccDataType,   // ScaleDataType
+                                                             AccDataType,   // BiasDataType
+                                                             AccDataType,   // MeanVarDataType
+                                                             PassThroughOp, // YElementwiseOp
+                                                             Rank,
+                                                             NumReduceDim,
+                                                             UseMultiblockInK,
+                                                             256,
+                                                             16,
+                                                             16,
+                                                             1,
+                                                             2,
+                                                             0,
+                                                             1,
+                                                             1,
+                                                             1,
+                                                             1,
+                                                             1>;
+
+    auto batchnorm_fwd = DeviceBatchNormFwdInstance{};
+
+    auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
+        i_inOutLengths,
+        i_inOutStrides,
+        i_inOutStrides,
+        {0, 1, 2},
+        i_scaleBiasMeanVarLengths,
+        i_scaleBiasMeanVarStrides,
+        i_scaleBiasMeanVarStrides,
+        i_scaleBiasMeanVarStrides,
+        x_dev.GetDeviceBuffer(),
+        bnScale_dev.GetDeviceBuffer(),
+        bnBias_dev.GetDeviceBuffer(),
+        epsilon,
+        PassThroughOp{},
+        y_dev.GetDeviceBuffer(),
+        saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
+        saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
+        averageFactor,
+        updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
+        updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
+
+    if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
+    {
+        std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, "
+                     "exiting!"
+                  << std::endl;
+        return (false);
+    };
+
+    size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());
+
+    DeviceMem workspace_dev(workspace_sz);
+
+    batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
+
+    auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();
+
+    if(time_kernel)
+    {
+        float avg_time   = 0.0f;
+        size_t num_bytes = 0;
+
+        size_t total_length =
+            inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
+        size_t invariant_length = inOutLengths[3];
+
+        avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+        // inputing of x, scale, bias, outputing of y
+        num_bytes += total_length * sizeof(InOutDataType) * 2 +
+                     invariant_length * sizeof(AccDataType) * 2;
+
+        // outputing of mean, inv-variance
+        num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;
+
+        // updating of moving mean, variance
+        num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;
+
+        float gb_per_sec = num_bytes / 1.E6 / avg_time;
+
+        std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
+    }
+    else
+        (void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
 
     bool pass = true;
 
     if(do_verification)
     {
-        auto batchNormFwd_ref = ReferenceBatchNormFwdInstance<InOutDataType, AccDataType>{};
+        using ReferenceBatchNormFwdInstance =
+            ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<
+                InOutDataType,
+                InOutDataType,
+                AccDataType,
+                AccDataType,
+                AccDataType,
+                AccDataType,
+                PassThroughOp>;
+
+        auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
 
         auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
             i_inOutLengths,
             i_inOutStrides,
             i_inOutStrides,
+            {0, 1, 2},
             i_scaleBiasMeanVarLengths,
             i_scaleBiasMeanVarStrides,
+            i_scaleBiasMeanVarStrides,
+            i_scaleBiasMeanVarStrides,
             x.mData.data(),
             bnScale.mData.data(),
             bnBias.mData.data(),
-            y_ref.mData.data(),
-            0.1, // exponentialAverageFactor
-            updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
-            updateMovingAverage ? resultRunningVariance_ref.mData.data()
-                                : nullptr, // resultRunningVariance
-            epsilon,
-            saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
-            saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr);
+            epsilon,
+            PassThroughOp{},
+            y_ref.mData.data(),
+            saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
+            saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
+            averageFactor,
+            updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
+            updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
 
         if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
         {
             std::cout
-                << "The runtime parameters seems not supported by the BatchNorm reference instance, exiting!"
+                << "The runtime parameters seems not supported by the BatchNorm reference "
+                   "instance, exiting!"
                 << std::endl;
-            return (-2);
+            return (false);
         };
 
         auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();

@@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
         if(saveMeanAndInvVariance)
         {
+            using ck::host_common::dumpBufferToFile;
+
             Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
             Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);

@@ -396,70 +464,129 @@ int main(int argc, char* argv[])
         if(arg.data_type == 0)
         {
-            pass = bnorm_fwd_nhwc_test<ck::half_t, float>(
-                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
-                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
+            if(arg.use_multiblock_welford)
+                pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(
+                    arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
+                    arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
+            else
+                pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(
+                    arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
+                    arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
         }
         else if(arg.data_type == 1)
         {
-            pass = bnorm_fwd_nhwc_test<float, float>(
-                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
-                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
+            if(arg.use_multiblock_welford)
+                pass = bnorm_fwd_nhwc_test<float, float, true>(
+                    arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
+                    arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
+            else
+                pass = bnorm_fwd_nhwc_test<float, float, false>(
+                    arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
+                    arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
         }
         else if(arg.data_type == 3)
         {
-            pass = bnorm_fwd_nhwc_test<int8_t, float>(
-                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
-                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
+            if(arg.use_multiblock_welford)
+                pass = bnorm_fwd_nhwc_test<int8_t, float, true>(
+                    arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
+                    arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
+            else
+                pass = bnorm_fwd_nhwc_test<int8_t, float, false>(
+                    arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
+                    arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        }
        else if(arg.data_type == 5)
        {
-            pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float>(
-                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
-                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
+            if(arg.use_multiblock_welford)
+                pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(
+                    arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
+                    arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
+            else
+                pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(
+                    arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
+                    arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        }
        else if(arg.data_type == 6)
        {
-            pass = bnorm_fwd_nhwc_test<double, double>(
-                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
-                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
+            if(arg.use_multiblock_welford)
+                pass = bnorm_fwd_nhwc_test<double, double, true>(
+                    arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
+                    arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
+            else
+                pass = bnorm_fwd_nhwc_test<double, double, false>(
+                    arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
+                    arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        }
     }
     else
     {
-        pass = bnorm_fwd_nhwc_test<ck::half_t, float>(true,
-                                                      2,
-                                                      false, // don't time kernel
-                                                      {128, 16, 16, 1024},
-                                                      true,
-                                                      false,
-                                                      averageFactor,
-                                                      epsilon);
+        pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(true,
+                                                            2,
+                                                            false, // don't time kernel
+                                                            {128, 16, 6, 512},
+                                                            true,
+                                                            true,
+                                                            averageFactor,
+                                                            epsilon);
+
+        pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(true,
+                                                                     2,
+                                                                     false, // don't time kernel
+                                                                     {128, 16, 3, 1024},
+                                                                     true,
+                                                                     true,
+                                                                     averageFactor,
+                                                                     epsilon);
     };
 
     return (pass ? 0 : 1);
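Note on the new example flow: the rewritten example drives the batchnorm operator through the host-side pattern of argument creation, support check, caller-owned workspace, and invoker launch. The sketch below restates that pattern generically; DeviceOpT and the forwarded argument pack are placeholders rather than the exact CK signatures, and it assumes the DeviceMem and StreamConfig utilities already included by the example.

    // Sketch only: generic host-side driver pattern used by the new example.
    // DeviceOpT stands in for a concrete instance such as DeviceBatchNormFwdInstance above;
    // the argument list is abbreviated with a parameter pack.
    #include <utility>

    template <typename DeviceOpT, typename... Args>
    bool run_device_op(Args&&... args)
    {
        DeviceOpT device_op{};

        // 1. Describe the problem (lengths, strides, pointers, element-wise ops, ...).
        auto argument_ptr = device_op.MakeArgumentPointer(std::forward<Args>(args)...);

        // 2. Bail out early if this tuning instance cannot handle the runtime shape.
        if(!device_op.IsSupportedArgument(argument_ptr.get()))
            return false;

        // 3. The workspace is now owned by the caller: query the size, allocate, attach.
        DeviceMem workspace_dev(device_op.GetWorkSpaceSize(argument_ptr.get()));
        device_op.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());

        // 4. Launch through an invoker; Run() returns the averaged kernel time when timed.
        auto invoker_ptr = device_op.MakeInvokerPointer();
        invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        return true;
    }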
example/34_batchnorm/batchnorm_infer_impl.hpp  (+27 -15)

@@ -14,8 +14,12 @@
 #include "batchnorm_common.hpp"
 
-template <typename InOutDataType,
+template <typename XDataType,
+          typename YDataType,
           typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
           ck::index_t Rank,
           ck::index_t NumBatchNormReduceDim,
           bool fastest_dim_is_reduced = false>

@@ -26,7 +30,9 @@ int bnorm_infer(
     const std::array<ck::index_t, Rank> xStrides,
     const std::array<ck::index_t, Rank> yStrides,
     const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
-    const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
+    const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
+    const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
+    const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
     const void* p_x,
     const void* p_scale,
     const void* p_bias,

@@ -41,11 +47,11 @@ int bnorm_infer(
         "Invalid number of reduced dimensions for batchnorm!");
 
     using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
-        ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
+        ck::Tuple<XDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
                                                                                    // variance,
                                                                                    // scale,
                                                                                    // bias,
-        ck::Tuple<InOutDataType>, // y
+        ck::Tuple<YDataType>, // y
         NormalizeInInfer,
         Rank,
         2, // MPerthread

@@ -53,14 +59,18 @@ int bnorm_infer(
         ck::Sequence<1>>; // scalarPerVector: y
 
     auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
-    std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};
+    std::array<ck::index_t, Rank> aligned_bnScaleStrides{0};
+    std::array<ck::index_t, Rank> aligned_bnBiasStrides{0};
+    std::array<ck::index_t, Rank> aligned_bnMeanVarStrides{0};
 
     int i = 0;
 
     for(auto dim : invariantDims)
     {
         assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
-        aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
+        aligned_bnScaleStrides[dim]   = bnScaleStrides[i];
+        aligned_bnBiasStrides[dim]    = bnBiasStrides[i];
+        aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i];
         i++;
     };

@@ -84,10 +94,10 @@ int bnorm_infer(
     auto argument_ptr1 = dev_normalize.MakeArgumentPointer(
         xyLengths,
         {xStrides,
-         aligned_scaleBiasMeanVarStrides,
-         aligned_scaleBiasMeanVarStrides,
-         aligned_scaleBiasMeanVarStrides,
-         aligned_scaleBiasMeanVarStrides},
+         aligned_bnMeanVarStrides,
+         aligned_bnMeanVarStrides,
+         aligned_bnScaleStrides,
+         aligned_bnBiasStrides},
         {yStrides},
         {p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias},
         {p_y},

@@ -105,8 +115,10 @@ int bnorm_infer(
         avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
 
-        num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
-                      total_length * sizeof(InOutDataType));
+        num_bytes += total_length * sizeof(XDataType) +
+                     invariantLength * (sizeof(ScaleDataType) + sizeof(BiasDataType) +
+                                        2 * sizeof(MeanVarDataType)) +
+                     total_length * sizeof(YDataType);
 
     if(time_kernel)
     {
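The new byte accounting in bnorm_infer reads x and writes y once per element and reads scale, bias, mean and variance once per channel. A small worked example of that formula, with an assumed 128x16x16x1024 NHWC shape, fp16 x/y and fp32 per-channel statistics (none of these numbers come from the commit):

    // Worked example of the bandwidth accounting; the shape and kernel time are assumptions.
    #include <cstddef>
    #include <cstdio>

    int main()
    {
        const std::size_t total_length     = 128ull * 16 * 16 * 1024; // N*H*W*C elements
        const std::size_t invariant_length = 1024;                    // C channels

        // Mirrors the new accounting: read x and write y once per element,
        // read scale, bias, mean and variance once per channel.
        const std::size_t num_bytes = total_length * 2                      // x (fp16)
                                      + invariant_length * (4 + 4 + 2 * 4)  // scale, bias, mean, variance (fp32)
                                      + total_length * 2;                   // y (fp16)

        const double avg_time_ms = 1.0; // hypothetical averaged kernel time
        std::printf("bytes moved: %zu, effective bandwidth: %.1f GB/s\n",
                    num_bytes, num_bytes / 1.E6 / avg_time_ms);
        return 0;
    }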
example/34_batchnorm/batchnorm_infer_nhwc.cpp  (+35 -21)

@@ -18,11 +18,6 @@
 #include "batchnorm_infer_impl.hpp"
 
-template <typename InOutDataType, typename AccDataType>
-using ReferenceBatchNormInferInstance = ck::tensor_operation::host::
-    ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType, AccDataType>;
-
 static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
                                        {"verify", required_argument, nullptr, 'v'},
                                        {"help", no_argument, nullptr, '?'},

@@ -236,21 +231,30 @@ bool bnorm_infer_nhwc_test(bool do_verification,
     int result = 0;
 
-    result = bnorm_infer<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
-        time_kernel,
-        {0, 1, 2},
-        i_inOutLengths,
-        i_inOutStrides,
-        i_inOutStrides,
-        i_scaleBiasMeanVarLengths,
-        i_scaleBiasMeanVarStrides,
-        x_dev.GetDeviceBuffer(),
-        bnScale_dev.GetDeviceBuffer(),
-        bnBias_dev.GetDeviceBuffer(),
-        epsilon,
-        estimatedMean_dev.GetDeviceBuffer(),
-        estimatedVariance_dev.GetDeviceBuffer(),
-        y_dev.GetDeviceBuffer());
+    result = bnorm_infer<InOutDataType,
+                         InOutDataType,
+                         AccDataType,
+                         AccDataType,
+                         AccDataType,
+                         AccDataType,
+                         Rank,
+                         NumReduceDim,
+                         false>(time_kernel,
+                                {0, 1, 2},
+                                i_inOutLengths,
+                                i_inOutStrides,
+                                i_inOutStrides,
+                                i_scaleBiasMeanVarLengths,
+                                i_scaleBiasMeanVarStrides,
+                                i_scaleBiasMeanVarStrides,
+                                i_scaleBiasMeanVarStrides,
+                                x_dev.GetDeviceBuffer(),
+                                bnScale_dev.GetDeviceBuffer(),
+                                bnBias_dev.GetDeviceBuffer(),
+                                epsilon,
+                                estimatedMean_dev.GetDeviceBuffer(),
+                                estimatedVariance_dev.GetDeviceBuffer(),
+                                y_dev.GetDeviceBuffer());
 
     if(result < 0)
         return (false);

@@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification,
     if(do_verification)
     {
-        auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{};
+        using ReferenceBatchNormInferInstance =
+            ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<
+                InOutDataType,
+                InOutDataType,
+                AccDataType,
+                AccDataType,
+                AccDataType,
+                AccDataType>;
+
+        auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
 
         auto argument_ptr_ref =
             batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,

@@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
                                                    i_inOutStrides,
                                                    i_scaleBiasMeanVarLengths,
                                                    i_scaleBiasMeanVarStrides,
+                                                   i_scaleBiasMeanVarStrides,
+                                                   i_scaleBiasMeanVarStrides,
                                                    x.mData.data(),
                                                    bnScale.mData.data(),
                                                    bnBias.mData.data(),
include/ck/ck.hpp  (+5 -0)

@@ -168,6 +168,11 @@
 // tuning parameter
 #define CK_WORKAROUND_SWDEV_325164 0
 
+// workaround: disable broken fused attention kernel instance that does not pass validation
+// issue found on mi100/#10738 combo when irregular KPerBlock attention kernel has acc0 scaling
+// enabled
+#define CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE 1
+
 namespace ck {
 
 enum struct InMemoryDataOperationEnum
include/ck/tensor_description/tensor_space_filling_curve.hpp  (+6 -4)

@@ -14,7 +14,8 @@ namespace ck {
 template <typename TensorLengths,
           typename DimAccessOrder,
-          typename ScalarsPerAccess> // # of scalars per access in each dimension
+          typename ScalarsPerAccess, // # of scalars per access in each dimension
+          bool SnakeCurved = true>
 struct SpaceFillingCurve
 {
     static constexpr index_t nDim = TensorLengths::Size();

@@ -136,9 +137,10 @@ struct SpaceFillingCurve
         Index ordered_idx;
 
         static_for<0, nDim, 1>{}([&](auto idim) {
-            ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim]
-                                                    : ordered_access_lengths[idim] - 1 -
-                                                          ordered_access_idx[idim];
+            ordered_idx(idim) =
+                !SnakeCurved || forward_sweep[idim]
+                    ? ordered_access_idx[idim]
+                    : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
         });
 
         return container_reorder_given_old2new(ordered_idx, dim_access_order) *
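The new SnakeCurved parameter controls whether the space-filling curve mirrors the index on backward sweeps (snake order) or always walks each dimension forward. A standalone sketch of just that index flip, using plain ints instead of ck::static_for and Number<> so it runs on the host:

    // Standalone illustration of the SnakeCurved condition; not CK's template machinery.
    #include <cstdio>

    template <bool SnakeCurved>
    int ordered_index(bool forward_sweep, int idx, int length)
    {
        // Matches the new condition: a non-snake curve never mirrors the index,
        // a snake curve mirrors it on backward sweeps.
        return (!SnakeCurved || forward_sweep) ? idx : length - 1 - idx;
    }

    int main()
    {
        const int length = 4;
        for(int idx = 0; idx < length; ++idx)
            std::printf("backward sweep, idx %d -> snake %d, linear %d\n",
                        idx,
                        ordered_index<true>(false, idx, length),
                        ordered_index<false>(false, idx, length));
        return 0;
    }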
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp  (+36 -0)

@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         return make_tuple(c_thread_m, c_thread_n);
     }
 
+    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+    __device__ static auto
+        CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
+    {
+        const auto wave_idx = GetWaveIdx();
+
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
+
+        return make_tuple(Number<m0>{},
+                          Number<n0>{},
+                          waveId_m,
+                          waveId_n,
+                          blk_idx[I0],
+                          blk_idx[I1],
+                          blk_idx[I2],
+                          blk_idx[I3]);
+    }
+
     __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
     {
         static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&

@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
         return make_tuple(c_thread_m, c_thread_n);
     }
 
+    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+    __device__ static auto
+        CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
+    {
+        const auto wave_idx = GetWaveIdx();
+
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
+
+        return make_tuple(
+            m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
+    }
+
     using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
 
     __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp  (+2 -1)

@@ -24,7 +24,8 @@ template <typename ALayout,
           typename B0ElementwiseOperation,
           typename Acc0ElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          bool MaskOutUpperTriangle> // TODO: enum for mask type
 struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp  (+39 -28)

@@ -7,49 +7,60 @@
 #include <vector>
 
 #include "device_base.hpp"
+#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template <typename ALayout,
-          typename B0Layout,
-          typename B1Layout,
-          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          index_t NumDimO,
           typename ADataType,
           typename B0DataType,
           typename B1DataType,
           typename CDataType,
+          typename Acc0BiasDataType,
+          typename Acc1BiasDataType,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
           typename Acc0ElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          MaskingSpecialization MaskingSpec>
 struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b0,
-                        const void* p_b1,
-                        void* p_c,
-                        ck::index_t M,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t O,
-                        ck::index_t Batch,
-                        std::vector<index_t> c_gs_ms_os_lengths,
-                        std::vector<index_t> c_gs_ms_os_strides,
-                        ck::index_t StrideA,
-                        ck::index_t StrideB0,
-                        ck::index_t StrideB1,
-                        ck::index_t BatchStrideA,
-                        ck::index_t BatchStrideB0,
-                        ck::index_t BatchStrideB1,
-                        AElementwiseOperation a_element_op,
-                        B0ElementwiseOperation b0_element_op,
-                        Acc0ElementwiseOperation acc0_element_op,
-                        B1ElementwiseOperation b1_element_op,
-                        CElementwiseOperation c_element_op) = 0;
+    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
+    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b0,
+                        const void* p_b1,
+                        void* p_c,
+                        const std::array<void*, NumAcc0Bias> p_acc0_biases,
+                        const std::array<void*, NumAcc1Bias> p_acc1_biases,
+                        const std::vector<index_t>& a_gs_ms_ks_lengths,
+                        const std::vector<index_t>& a_gs_ms_ks_strides,
+                        const std::vector<index_t>& b_gs_ns_ks_lengths,
+                        const std::vector<index_t>& b_gs_ns_ks_strides,
+                        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
+                        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
+                        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
+                        const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+                        const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
+                        const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
+                        const std::array<std::vector<index_t>, NumAcc1Bias>
+                            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
+                        const std::array<std::vector<index_t>, NumAcc1Bias>
+                            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
+                        AElementwiseOperation a_element_op,
+                        B0ElementwiseOperation b0_element_op,
+                        Acc0ElementwiseOperation acc0_element_op,
+                        B1ElementwiseOperation b1_element_op,
+                        CElementwiseOperation c_element_op) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
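The interface now selects masking through a MaskingSpecialization template value instead of a bool. The predicate behind MaskOutUpperTriangle, as visible in the xdl_cshuffle implementation later in this commit, masks the strictly upper triangle plus out-of-range columns. A standalone sketch of that check (not the CK class itself):

    // Standalone sketch of the causal-mask predicate selected by MaskOutUpperTriangle;
    // mirrors the IsUpperTriangle / IsNOutOfBound logic shown later in this diff.
    #include <cstdio>

    struct UpperTriangleMask
    {
        int n_raw; // number of valid columns

        bool is_masked(int m, int n) const
        {
            return n > m       // strictly upper triangle (future positions in attention)
                || n >= n_raw; // columns beyond the real sequence length
        }
    };

    int main()
    {
        UpperTriangleMask mask{3};
        for(int m = 0; m < 3; ++m)
            for(int n = 0; n < 4; ++n)
                std::printf("m=%d n=%d masked=%d\n", m, n, static_cast<int>(mask.is_masked(m, n)));
        return 0;
    }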
include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp  (+13 -8)

@@ -13,31 +13,36 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template <index_t Rank, index_t NumBatchNormReduceDim>
+template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
 struct DeviceBatchNormFwd : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
         const std::array<index_t, Rank> xyLengths,
         const std::array<index_t, Rank> xStrides,
         const std::array<index_t, Rank> yStrides,
+        const std::array<int, NumBatchNormReduceDim> reduceDims,
         const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
-        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
         const void* p_x,
         const void* bnScale,
         const void* bnBias,
+        double epsilon,
+        const YElementwiseOp y_elementwise_op,
         void* p_y,
+        void* resultSaveMean,
+        void* resultSaveInvVariance,
         double exponentialAverageFactor,
         void* resultRunningMean,
-        void* resultRunningVariance,
-        double epsilon,
-        void* resultSaveMean,
-        void* resultSaveInvVariance) = 0;
+        void* resultRunningVariance) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
-template <index_t Rank, index_t NumBatchNormReduceDim>
-using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>;
+template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
+using DeviceBatchNormFwdPtr =
+    std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
 
 } // namespace device
 } // namespace tensor_operation
include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp  (+3 -1)

@@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator
         const std::array<index_t, Rank> xStrides,
         const std::array<index_t, Rank> yStrides,
         const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
-        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
+        const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
         const void* p_x,
         const void* bnScale,
         const void* bnBias,
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp  (+25 -19)

@@ -7,46 +7,50 @@
 #include <vector>
 
 #include "device_base.hpp"
+#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template <typename ALayout,
-          typename B0Layout,
-          typename B1Layout,
-          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          index_t NumDimO,
           typename ADataType,
           typename B0DataType,
           typename B1DataType,
           typename CDataType,
+          typename Acc0BiasDataType,
+          typename Acc1BiasDataType,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
           typename Acc0ElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          MaskingSpecialization MaskingSpec>
 struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
 {
     struct ProblemDesc
     {
-        // Overall problem shape
-        index_t M;
-        index_t N;
-        index_t K;
-        index_t O;
-        index_t Batch;
-
-        // Stride for A/B0/B1; layout determined by template args
-        index_t StrideA;
-        index_t StrideB0;
-        index_t StrideB1;
-        index_t BatchStrideA;
-        index_t BatchStrideB0;
-        index_t BatchStrideB1;
-
-        // Lengths and strides for output C
+        std::vector<index_t> a_gs_ms_ks_lengths;
+        std::vector<index_t> a_gs_ms_ks_strides;
+
+        std::vector<index_t> b0_gs_ns_ks_lengths;
+        std::vector<index_t> b0_gs_ns_ks_strides;
+
+        std::vector<index_t> b1_gs_os_ns_lengths;
+        std::vector<index_t> b1_gs_os_ns_strides;
+
         std::vector<index_t> c_gs_ms_os_lengths;
         std::vector<index_t> c_gs_ms_os_strides;
+
+        std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
+        std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
+
+        std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
+        std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
     };
 
     virtual std::unique_ptr<BaseArgument>

@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
                         std::vector<const void*> p_b0_vec,
                         std::vector<const void*> p_b1_vec,
                         std::vector<void*> p_c_vec,
+                        std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                        std::vector<std::vector<const void*>> p_acc1_biases_vec,
                         std::vector<ProblemDesc> problem_desc_vec,
                         AElementwiseOperation a_element_op,
                         B0ElementwiseOperation b0_element_op,
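ProblemDesc now describes each tensor with [G..., M/N/O..., K/N...] length and stride vectors instead of scalar sizes plus batch strides. A hedged sketch of filling it for a single G dimension with packed row-major tensors; only the field names come from this header, and the stride convention shown is an assumption, not something the interface dictates:

    // Hedged sketch: one way to fill the new ProblemDesc for a single-head case with
    // G = batch, A[G, M, K], B0[G, N, K], B1[G, O, N], C[G, M, O], all packed.
    // "ProblemDesc" refers to the nested struct declared above.
    ProblemDesc make_problem_desc(ck::index_t G, ck::index_t M, ck::index_t N,
                                  ck::index_t K, ck::index_t O)
    {
        ProblemDesc desc;
        desc.a_gs_ms_ks_lengths  = {G, M, K};
        desc.a_gs_ms_ks_strides  = {M * K, K, 1}; // assumed packed row-major layout
        desc.b0_gs_ns_ks_lengths = {G, N, K};
        desc.b0_gs_ns_ks_strides = {N * K, K, 1};
        desc.b1_gs_os_ns_lengths = {G, O, N};
        desc.b1_gs_os_ns_strides = {O * N, N, 1};
        desc.c_gs_ms_os_lengths  = {G, M, O};
        desc.c_gs_ms_os_strides  = {M * O, O, 1};
        // No fused bias tensors in this sketch; the bias vectors stay empty.
        return desc;
    }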
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
4fec5ad3
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -54,9 +55,8 @@ __global__ void
...
@@ -54,9 +55,8 @@ __global__ void
index_t
right
=
group_count
;
index_t
right
=
group_count
;
index_t
group_id
=
index_t
((
left
+
right
)
/
2
);
index_t
group_id
=
index_t
((
left
+
right
)
/
2
);
while
((
!
(
block_id
>=
arg_ptr
[
group_id
].
block_start_
&&
while
(
block_id
<
arg_ptr
[
group_id
].
block_end_
))
&&
(
!
(
block_id
>=
arg_ptr
[
group_id
].
block_start_
&&
block_id
<
arg_ptr
[
group_id
].
block_end_
)))
left
<=
right
)
{
{
if
(
block_id
<
arg_ptr
[
group_id
].
block_start_
)
if
(
block_id
<
arg_ptr
[
group_id
].
block_start_
)
{
{
...
@@ -114,14 +114,17 @@ __global__ void
...
@@ -114,14 +114,17 @@ __global__ void
// Computes C = A * B0 * B1
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
// ^^^^^^^^^^^ (Acc1)
template
<
typename
ALayout
,
template
<
index_t
NumDimG
,
typename
BLayout
,
// B0Layout
index_t
NumDimM
,
typename
B1Layout
,
index_t
NumDimN
,
typename
CPermuteNumDims_G_M_Gemm1N
,
// Sequence<NumDimG, NumDimM, NumDimGemm1N>
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
B1DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
CDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
GemmAccDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
...
@@ -130,6 +133,10 @@ template <typename ALayout,
...
@@ -130,6 +133,10 @@ template <typename ALayout,
typename
B1ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
TensorSpecialization
ASpec
,
TensorSpecialization
BSpec
,
TensorSpecialization
B1Spec
,
TensorSpecialization
CSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -170,297 +177,152 @@ template <typename ALayout,
...
@@ -170,297 +177,152 @@ template <typename ALayout,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
bool
MaskOutUpperTriangle
,
MaskingSpecialization
MaskingSpec
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
struct
DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
:
public
DeviceGroupedGemmSoftmaxGemmPermute
<
ALayout
,
:
public
DeviceGroupedGemmSoftmaxGemmPermute
<
NumDimG
,
BLayout
,
NumDimM
,
B1Layout
,
NumDimN
,
CPermuteNumDims_G_M_Gemm1N
,
NumDimK
,
NumDimO
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
B1DataType
,
B1DataType
,
CDataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
>
CElementwiseOperation
,
MaskingSpec
>
{
{
using
DeviceOp
=
DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
;
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
using
ProblemDesc
=
"Number of dimension must be greater than 0"
);
typename
DeviceGroupedGemmSoftmaxGemmPermute
<
ALayout
,
BLayout
,
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
B1Layout
,
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
CPermuteNumDims_G_M_Gemm1N
,
ADataType
,
// TODO ANT: implement bias combination
BDataType
,
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
B1DataType
,
CDataType
,
#if 0
AElementwiseOperation
,
// TODO ANT: use alias
BElementwiseOperation
,
static constexpr index_t NumDimGemm0M = NumDimM;
AccElementwiseOperation
,
static constexpr index_t NumDimGemm0N = NumDimN;
B1ElementwiseOperation
,
static constexpr index_t NumDimGemm0K = NumDimK;
CElementwiseOperation
>::
ProblemDesc
;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using
DeviceOp
=
DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
;
using
ProblemDesc
=
typename
DeviceGroupedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
BDataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
MaskingSpec
>::
ProblemDesc
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
matrix_padder
=
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
GemmGemmPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
,
index_t
>
{
Sequence
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
>
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
};
Sequence
<
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
>
,
GemmSpec
,
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
ASpec
,
{
BSpec
,
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
B1Spec
,
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
CSpec
>
;
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
make_tuple
(
StrideA
,
I1
));
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
}
}();
const
auto
a_grid_desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
AK1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
}
}();
const
auto
b_grid_desc_n_k
=
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
BK1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
{
const
auto
b1_grid_desc_nraw_kraw
=
[
&
]()
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>::
value
)
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
{
Number
<
AK1
>
{});
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
B1Layout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
}
}();
const
auto
b1_grid_desc_n_k
=
matrix_padder
.
PadB1Descriptor_N_K
(
b1_grid_desc_nraw_kraw
);
const
auto
N
=
b1_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b1_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
B1K0
=
K
/
B1K1
;
return
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths_vec
,
static
auto
MakeCGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides_vec
)
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides_vec
)
{
{
constexpr
index_t
NumDimG
=
CPermuteNumDims_G_M_Gemm1N
::
At
(
I0
);
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
constexpr
index_t
NumDimM
=
CPermuteNumDims_G_M_Gemm1N
::
At
(
I1
);
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths_vec
,
+                                              b_gs_ns_ks_strides_vec),
+            Number<BK1>{});
-        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
-
-        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
-               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
-
-        const auto to_tuple = [&](auto& vec, auto start, auto end) {
-            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
-        };
-
-        const auto c_ms_ns_lengths = to_tuple(
-            c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
-        const auto c_ms_ns_strides = to_tuple(
-            c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
-
-        // dimension Ids for M0, M1, ...
-        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
-
-        // dimension Ids for N0, N1, ...
-        constexpr auto nDimIds =
-            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
-
-        // lengths for M0, M1, ...
-        const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
-
-        // lengths for K0, K1, ...
-        const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
-
-        // naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
-        const auto c_grid_desc_ms_ns =
-            make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
-
-        // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
-        const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
-            c_grid_desc_ms_ns,
-            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
-            make_tuple(mDimIds, nDimIds),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
     }

-    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
-                                          const std::vector<index_t>& c_gs_ms_ns_strides_vec)
+    static auto
+    MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
+                                   const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
     {
-        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
-        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
-        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
-
-        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
-               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
-
-        const auto to_tuple = [&](auto& vec, auto start, auto end) {
-            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
-        };
-
-        const auto c_gs_ms_ns_lengths =
-            to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
-        const auto c_gs_ms_ns_strides =
-            to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
-
-        // dimension Ids for G0, G1, ...
-        constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
-
-        // dimension Ids for M0, M1, ...
-        constexpr auto mDimIds =
-            typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
-
-        // dimension Ids for N0, N1, ...
-        constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
-                                                                  NumDimG + NumDimM + NumDimN,
-                                                                  1>::type{};
-
-        // lengths for G0, G1, ...
-        const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
-
-        // lengths for M0, M1, ...
-        const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
-
-        // lengths for K0, K1, ...
-        const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
-
-        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-        const auto c_grid_desc_gs_ms_ns =
-            make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
-
-        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
-        // N2 * ...]
-        const auto c_grid_desc_g_mraw_nraw = transform_tensor_descriptor(
-            c_grid_desc_gs_ms_ns,
-            make_tuple(make_merge_transform(gLengths),
-                       make_merge_transform(mLengths),
-                       make_merge_transform(nLengths)),
-            make_tuple(gDimIds, mDimIds, nDimIds),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-        // this desc is only for calculating batch offset so no padding needed
-        return c_grid_desc_g_mraw_nraw;
+        return Transform::MakeB1GridDescriptor_BK0_N_BK1(
+            Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
+                                                b1_gs_gemm1ns_gemm1ks_strides_vec),
+            Number<B1K1>{});
     }

-    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
-    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
-    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
-    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N({}, {}));
-    using CGridDesc_G_M_N      = decltype(MakeCGridDescriptor_G_M_N({}, {}));
+    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
+    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
+    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
+    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
+    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
+    using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
+    using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
+    using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
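The removed MakeCGridDescriptor_M_N flattened a multi-dimensional C[M0, M1, ..., N0, N1, ...] into a 2-D [M, N] view by merging each dimension group; the new code delegates that work to the Transform helper. The following is a minimal standalone sketch of that flattening idea only; the names and the simple stride arithmetic are illustrative assumptions, not the CK descriptor machinery.

    // Standalone sketch (assumed names): merging a dimension group and addressing the
    // flattened [M, N] view, mirroring the make_merge_transform calls above.
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Merged extent of one dimension group, e.g. M = M0 * M1 * ...
    std::int64_t merged_extent(const std::vector<std::int64_t>& lengths)
    {
        return std::accumulate(lengths.begin(), lengths.end(), std::int64_t{1},
                               [](std::int64_t a, std::int64_t b) { return a * b; });
    }

    // Offset of element (m, n) in the flattened view, assuming the dimensions inside
    // each group are contiguous and ordered (the "packed" case discussed later for
    // TensorSpecialization::Packed).
    std::int64_t flat_offset(std::int64_t m, std::int64_t n, std::int64_t merged_n)
    {
        return m * merged_n + n;
    }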
-    // to track the points which need to be set to -inf on C0
-    // Note: no need to reset M padding value, because they will not be stored out.
-    struct C0MatrixMask
-    {
-        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
-
-        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
-
-        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
-        {
-            return n >= NRaw_;
-        }
-
-        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
-        {
-            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
-        }
-
-        private:
-        // index_t MRaw_;
-        index_t NRaw_;
-    };
+    constexpr static auto make_MaskOutPredicate()
+    {
+        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
+        {
+            return MaskDisabledPredicate{};
+        }
+        else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
+        {
+            return MaskOutUpperTrianglePredicate{};
+        }
+    }
+    using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
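The change above replaces the hand-written mask struct with a compile-time predicate selected by MaskingSpec. A minimal standalone sketch of the same selection logic, using assumed names rather than the CK predicate types:

    // Standalone sketch (assumed names): an element is masked if its column is past the
    // valid N extent (set to -inf before softmax), and additionally if it lies in the
    // upper triangle when causal masking is enabled.
    #include <cstdint>

    enum class MaskMode { Disabled, UpperTriangle };

    template <MaskMode Mode>
    bool is_masked_element(std::int64_t m, std::int64_t n, std::int64_t n_raw)
    {
        if(n >= n_raw) // out-of-bound column
            return true;
        if constexpr(Mode == MaskMode::UpperTriangle)
            return n > m; // causal mask: row m may only attend to columns <= m
        else
            return false;
    }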
     struct ComputeBasePtrOfStridedBatch
     {
-        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
-                                     index_t BatchStrideB,
-                                     index_t BatchStrideB1,
-                                     CGridDesc_G_M_N c_grid_desc_g_m_n)
-            : BatchStrideA_(BatchStrideA),
-              BatchStrideB_(BatchStrideB),
-              BatchStrideB1_(BatchStrideB1),
+        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
+                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
+                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
+                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
+            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
+              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
+              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
               c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
         {
         }

         __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideA_);
+            return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideB_);
+            return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideB1_);
+            return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
...
@@ -469,9 +331,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         }

         private:
-        index_t BatchStrideA_;
-        index_t BatchStrideB_;
-        index_t BatchStrideB1_;
+        AGridDesc_G_M_K a_grid_desc_g_m_k_;
+        BGridDesc_G_N_K b_grid_desc_g_n_k_;
+        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
     };
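The batch base pointer is now derived from the G-major tensor descriptors instead of explicit batch strides. A minimal standalone sketch of the two equivalent computations (illustrative only, not the CK descriptor API):

    // Standalone sketch: with an explicit batch stride the base offset of batch g_idx is
    // g_idx * batch_stride; with a [G, M, K] descriptor it is the offset of element
    // (g_idx, 0, 0), which reduces to the same value when the G-dimension stride equals
    // the batch stride.
    #include <cstdint>

    std::int64_t base_offset_from_batch_stride(std::int64_t g_idx, std::int64_t batch_stride)
    {
        return g_idx * batch_stride;
    }

    std::int64_t base_offset_from_g_stride(std::int64_t g_idx, std::int64_t g_stride)
    {
        // plays the role of CalculateOffset(make_multi_index(g_idx, 0, 0)) above
        return g_idx * g_stride;
    }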
...
@@ -535,8 +397,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
-        matrix_padder.PadN,
-        MaskOutUpperTriangle>;
+        Transform::matrix_padder.PadN,
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;

     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...
@@ -570,16 +432,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     struct GroupDeviceArg
     {
-        // problem definiton
-        index_t M;
-        index_t N;
-        index_t K;
-        index_t O;
-
-        // Strides for the last dimensions of C for sanity check of vector load/store
-        index_t c_extent_lowest_;
-        index_t c_stride_lowest_;
+        // lengths for the last dimensions of overall problem for sanity check of vector load/store
+        std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
+
+        // strides for the last dimensions of each tensor for sanity check of vector load/store
+        std::vector<index_t> a_mz_kz_strides_;
+        std::vector<index_t> b_nz_kz_strides_;
+        std::vector<index_t> b1_nz_kz_strides_;
+        std::vector<index_t> c_mz_gemm1nz_strides_;

         // for gridwise gemm check
         CGridDesc_M_N c_grid_desc_m_n_;
     };
...
@@ -591,6 +453,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  std::vector<const void*> p_b_vec,
                  std::vector<const void*> p_b1_vec,
                  std::vector<void*> p_c_vec,
+                 std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                 std::vector<std::vector<const void*>> p_acc1_biases_vec,
                  std::vector<ProblemDesc> problem_desc_vec,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
...
@@ -603,6 +467,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
               b1_element_op_{b1_element_op},
               c_element_op_{c_element_op}
         {
+            // TODO ANT: implement bias addition
             group_count_ = problem_desc_vec.size();

             if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
...
@@ -611,6 +476,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
             }

+            if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
+            {
+                throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
+            }
+
             grid_size_ = 0;

             for(std::size_t i = 0; i < group_count_; i++)
...
@@ -620,14 +490,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
                 const auto p_c_grid  = static_cast<CDataType*>(p_c_vec[i]);

-                const auto a_grid_desc_ak0_m_ak1 = DeviceOp::MakeAGridDescriptor_AK0_M_AK1(
-                    problem_desc_vec[i].M, problem_desc_vec[i].K, problem_desc_vec[i].StrideA);
-                const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
-                    problem_desc_vec[i].K, problem_desc_vec[i].N, problem_desc_vec[i].StrideB0);
-                const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
-                    problem_desc_vec[i].N, problem_desc_vec[i].O, problem_desc_vec[i].StrideB1);
-                const auto c_grid_desc_m_n = DeviceOp::MakeCGridDescriptor_M_N(
-                    problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
+                const auto& problem_desc = problem_desc_vec[i];
+
+                const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
+                    problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
+                const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
+                    problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
+                const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
+                    problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
+                const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
+                    problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
+
+                const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
+                    problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
+                const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
+                    problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
+                const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
+                    problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
+                const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
+                    problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);

                 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...
@@ -635,25 +516,32 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 const index_t BlockStart     = grid_size_;
                 const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
-                const index_t grid_size_grp = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) *
-                                              problem_desc_vec[i].Batch;
+                const index_t batch_count    = c_grid_desc_g_m_n.GetLength(I0);
+                const index_t grid_size_grp =
+                    block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
                 const index_t BlockEnd = grid_size_ + grid_size_grp;

                 // batch stride
-                const auto c_grid_desc_g_m_n = DeviceOp::MakeCGridDescriptor_G_M_N(
-                    problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
-                const auto compute_base_ptr_of_batch =
-                    ComputeBasePtrOfStridedBatch(problem_desc_vec[i].BatchStrideA,
-                                                 problem_desc_vec[i].BatchStrideB0,
-                                                 problem_desc_vec[i].BatchStrideB1,
-                                                 c_grid_desc_g_m_n);
+                // TODO ANT: only keep batch stride in tensor desc to reduce scalar cache pressure
+                const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
+                    a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n);

                 // C0 mask
-                const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N);
+                const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));

                 grid_size_ += grid_size_grp;

+                // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
+                // so on
+                if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
+                     problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
+                     problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
+                     problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
+                {
+                    throw std::runtime_error("wrong! number of biases in function argument does not "
+                                             "match that in template argument");
+                }
+
                 group_kernel_args_.push_back({p_a_grid,
                                               p_b_grid,
                                               p_b1_grid,
...
@@ -669,13 +557,20 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                               BlockStart,
                                               BlockEnd});

-                group_device_args_.push_back({problem_desc_vec[i].M,
-                                              problem_desc_vec[i].N,
-                                              problem_desc_vec[i].K,
-                                              problem_desc_vec[i].O,
-                                              problem_desc_vec[i].c_gs_ms_os_lengths.back(),
-                                              problem_desc_vec[i].c_gs_ms_os_strides.back(),
-                                              c_grid_desc_m_n});
+                group_device_args_.push_back(
+                    {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
+                      problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
+                      problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
+                      problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]},
+                     {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
+                      problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
+                     {problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1],
+                      problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
+                     {problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1],
+                      problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
+                     {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
+                      problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
+                     c_grid_desc_m_n});
             }
         }
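The loop above carves the global grid into one contiguous block-id range per group, sized by that group's tile count times its batch count. A minimal standalone sketch of that bookkeeping (assumed field names, not the CK types):

    // Standalone sketch: assigning [BlockStart, BlockEnd) ranges group by group, as the
    // Argument constructor does with grid_size_ / grid_size_grp above.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct GroupRange { std::int64_t block_start; std::int64_t block_end; };

    std::vector<GroupRange> assign_block_ranges(const std::vector<std::int64_t>& tiles_per_group,
                                                const std::vector<std::int64_t>& batch_per_group)
    {
        std::vector<GroupRange> ranges;
        std::int64_t grid_size = 0;
        for(std::size_t i = 0; i < tiles_per_group.size(); ++i)
        {
            const std::int64_t grid_size_grp = tiles_per_group[i] * batch_per_group[i];
            ranges.push_back({grid_size, grid_size + grid_size_grp});
            grid_size += grid_size_grp; // next group starts where this one ends
        }
        return ranges;
    }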
...
@@ -788,6 +683,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 return false;
             }

+            // TODO ANT: Check if tensor specialization & strides mismatch
+
             bool all_has_main_k_block_loop  = true;
             bool some_has_main_k_block_loop = false;
...
@@ -815,19 +712,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             // Note: we need raw lengths since threadwise copy can not handle vector load when
             // part of vector is out of bounds
-            const auto MRaw      = device_arg.M;
-            const auto NRaw      = device_arg.N;
-            const auto KRaw      = device_arg.K;
-            const auto Gemm1NRaw = device_arg.O;
+            const auto MzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
+            const auto NzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
+            const auto KzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
+            const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];

             // Check scalar per vector requirement
-            const auto a_extent_lowest =
-                is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-            const auto b_extent_lowest =
-                is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-            const auto b1_extent_lowest =
-                is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-            const auto c_extent_lowest = device_arg.c_extent_lowest_;
+            const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
+            const auto b_extent_lowest  = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
+            const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
+            const auto c_extent_lowest  = Gemm1NzRaw;

             if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
                  b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
...
@@ -837,8 +731,22 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 return false;
             }

-            // Check vector store requirement; assumes last dimension in N to be contiguous
-            if(device_arg.c_stride_lowest_ != 1)
+            // Check vector load/store requirement
+            const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
+                                             ? device_arg.a_mz_kz_strides_[1]
+                                             : device_arg.a_mz_kz_strides_[0];
+            const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
+                                             ? device_arg.b_nz_kz_strides_[1]
+                                             : device_arg.b_nz_kz_strides_[0];
+            const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
+                                              ? device_arg.b1_nz_kz_strides_[1]
+                                              : device_arg.b1_nz_kz_strides_[0];
+            const auto c_stride_lowest =
+                device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
+                                                     // contiguous
+
+            if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
+                 c_stride_lowest == 1))
             {
                 return false;
             }
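The checks above gate vectorized memory access: each tensor's fastest raw extent must be divisible by the configured scalar-per-vector, and the stride condition keeps at least one lowest dimension contiguous. A minimal standalone sketch of the stricter per-tensor form of that test (illustrative; the device op above applies the divisibility test per tensor but a looser combined stride condition):

    // Standalone sketch: can this tensor be read/written with vector width
    // scalar_per_vector along its fastest dimension?
    #include <cstdint>

    bool supports_vector_access(std::int64_t extent_lowest,
                                std::int64_t stride_lowest,
                                std::int64_t scalar_per_vector)
    {
        return (extent_lowest % scalar_per_vector == 0) && (stride_lowest == 1);
    }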
...
@@ -873,6 +781,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                          std::vector<const void*> p_b_vec,
                          std::vector<const void*> p_b1_vec,
                          std::vector<void*> p_c_vec,
+                         std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                         std::vector<std::vector<const void*>> p_acc1_biases_vec,
                          std::vector<ProblemDesc> problem_desc_vec,
                          AElementwiseOperation a_element_op,
                          BElementwiseOperation b_element_op,
...
@@ -884,6 +794,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         p_b_vec,
                         p_b1_vec,
                         p_c_vec,
+                        p_acc0_biases_vec,
+                        p_acc1_biases_vec,
                         problem_desc_vec,
                         a_element_op,
                         b_element_op,
...
@@ -895,21 +807,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a_vec,
-                                                      std::vector<const void*> p_b_vec,
-                                                      std::vector<const void*> p_b1_vec,
-                                                      std::vector<void*> p_c_vec,
-                                                      std::vector<ProblemDesc> problem_desc_vec,
-                                                      AElementwiseOperation a_element_op,
-                                                      BElementwiseOperation b_element_op,
-                                                      AccElementwiseOperation acc_element_op,
-                                                      B1ElementwiseOperation b1_element_op,
-                                                      CElementwiseOperation c_element_op) override
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::vector<const void*> p_a_vec,
+                        std::vector<const void*> p_b_vec,
+                        std::vector<const void*> p_b1_vec,
+                        std::vector<void*> p_c_vec,
+                        std::vector<std::vector<const void*>> p_acc0_biases_vec,
+                        std::vector<std::vector<const void*>> p_acc1_biases_vec,
+                        std::vector<ProblemDesc> problem_desc_vec,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        AccElementwiseOperation acc_element_op,
+                        B1ElementwiseOperation b1_element_op,
+                        CElementwiseOperation c_element_op) override
     {
         return std::make_unique<Argument>(p_a_vec,
                                           p_b_vec,
                                           p_b1_vec,
                                           p_c_vec,
+                                          p_acc0_biases_vec,
+                                          p_acc1_biases_vec,
                                           problem_desc_vec,
                                           a_element_op,
                                           b_element_op,
...
@@ -942,7 +859,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
             << Gemm1NPerBlock << ", "
             << Gemm1KPerBlock << ", "
             << B1K1 << ", "
-            << getGemmSpecializationString(GemmSpec) << ">";
+            << getGemmSpecializationString(GemmSpec) << ", "
+            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
+            << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
+            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
+            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
+            << getMaskingSpecializationString(MaskingSpec) << ">";
         // clang-format on

         return str.str();
...
include/ck/tensor_operation/gpu/device/device_reduce.hpp
View file @ 4fec5ad3
...
@@ -3,27 +3,30 @@
 #pragma once

-#include <vector>
+#include <array>
 #include <memory>
-#include <iostream>

-#include "ck/utility/common_header.hpp"
+#include "ck/ck.hpp"
+#include "ck/utility/reduction_enums.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename InElementwiseOperation, typename AccElementwiseOperation>
+template <index_t Rank,
+          index_t NumReduceDim,
+          typename InElementwiseOperation,
+          typename AccElementwiseOperation>
 struct DeviceReduce : public BaseOperator
 {
+    static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
+
     virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<index_t> inLengths,
-                        const std::vector<index_t> inStrides,
-                        const std::vector<index_t> outLengths,
-                        const std::vector<index_t> outStrides,
-                        const std::vector<int> reduceDims,
+    MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
+                        const std::array<index_t, Rank> inStrides,
+                        const std::array<index_t, NumOutDim> outLengths,
+                        const std::array<index_t, NumOutDim> outStrides,
+                        const std::array<int, NumReduceDim> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
...
@@ -36,9 +39,12 @@ struct DeviceReduce : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <typename InElementwiseOperation, typename AccElementwiseOperation>
-using DeviceReducePtr =
-    std::unique_ptr<DeviceReduce<InElementwiseOperation, AccElementwiseOperation>>;
+template <index_t Rank,
+          index_t NumReduceDim,
+          typename InElementwiseOperation,
+          typename AccElementwiseOperation>
+using DeviceReducePtr = std::unique_ptr<
+    DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
...
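The DeviceReduce interface moves from runtime-sized std::vector parameters to rank-templated std::array parameters, so array sizes are validated at compile time. A minimal standalone sketch of that pattern, not the CK interface itself (the field names and example numbers are illustrative assumptions):

    // Standalone sketch: a rank/NumReduceDim-templated problem description whose array
    // sizes mirror the new MakeArgumentPointer signature above.
    #include <array>
    #include <cstdint>

    template <int Rank, int NumReduceDim>
    struct ReduceProblem
    {
        static constexpr int NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;

        std::array<std::int64_t, Rank> in_lengths;
        std::array<std::int64_t, Rank> in_strides;
        std::array<std::int64_t, NumOutDim> out_lengths;
        std::array<std::int64_t, NumOutDim> out_strides;
        std::array<int, NumReduceDim> reduce_dims;
    };

    // Example: reduce a rank-4 NHWC tensor over N, H, W (dims 0, 1, 2), leaving C = 32.
    ReduceProblem<4, 3> example{{16, 8, 8, 32},
                                {8 * 8 * 32, 8 * 32, 32, 1},
                                {32},
                                {1},
                                {0, 1, 2}};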
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
View file @ 4fec5ad3
...
@@ -130,8 +130,11 @@ namespace device {
 // D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
 // E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
-// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases, it
-// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1
+// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner
+// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
+// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
+// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
+// TensorSpecialization::Default with NumDimG/M/N/K = 1
 //
 // Detail- Packed tensor satisfies
 //   stride_0 = 1
...
@@ -147,7 +150,7 @@ namespace device {
 // essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
 //
 // Might need to expose dimension order to the interface to fully support
-// TensorSpecialization::Packed.
+// TensorSpecialization::Packed in a traditional sense of "packed" tensor
 template <index_t NumDimG,
           index_t NumDimM,
           index_t NumDimN,
...
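The "packed" condition referenced in the comment above is stride_0 = 1 with each subsequent stride equal to the previous stride times the previous length. A minimal standalone sketch of that check (illustrative only); group-level permutations are exactly what it rejects, which is why Packed degenerates into Default with NumDimG/M/N/K = 1:

    // Standalone sketch: lengths/strides ordered from slowest- to fastest-varying dimension.
    #include <cstdint>
    #include <vector>

    bool is_packed(const std::vector<std::int64_t>& lengths,
                   const std::vector<std::int64_t>& strides)
    {
        if(lengths.size() != strides.size())
            return false;

        std::int64_t expected = 1; // stride of the fastest dimension must be 1
        for(std::size_t i = lengths.size(); i-- > 0;)
        {
            if(strides[i] != expected)
                return false;
            expected *= lengths[i]; // stride_i = stride_{i+1} * length_{i+1} in this ordering
        }
        return true;
    }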
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
4fec5ad3
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -116,14 +117,17 @@ __global__ void
...
@@ -116,14 +117,17 @@ __global__ void
// Computes C = A * B0 * B1
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
// ^^^^^^^^^^^ (Acc1)
template
<
typename
ALayout
,
template
<
index_t
NumDimG
,
typename
BLayout
,
// B0Layout
index_t
NumDimM
,
typename
B1Layout
,
index_t
NumDimN
,
typename
CPermuteNumDims_G_M_Gemm1N
,
// Sequence<NumDimG, NumDimM, NumDimGemm1N>
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
B1DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
CDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
GemmAccDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
...
@@ -132,6 +136,10 @@ template <typename ALayout,
...
@@ -132,6 +136,10 @@ template <typename ALayout,
typename
B1ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
TensorSpecialization
ASpec
,
TensorSpecialization
BSpec
,
TensorSpecialization
B1Spec
,
TensorSpecialization
CSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -172,283 +180,135 @@ template <typename ALayout,
...
@@ -172,283 +180,135 @@ template <typename ALayout,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
bool
MaskOutUpperTriangle
,
MaskingSpecialization
MaskingSpec
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
struct
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
:
public
DeviceBatchedGemmSoftmaxGemmPermute
<
ALayout
,
:
public
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
BLayout
,
NumDimM
,
B1Layout
,
NumDimN
,
CPermuteNumDims_G_M_Gemm1N
,
NumDimK
,
NumDimO
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
B1DataType
,
B1DataType
,
CDataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
>
CElementwiseOperation
,
MaskingSpec
>
{
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
// TODO ANT: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
#if 0
// TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
;
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
matrix_padder
=
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
GemmGemmPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
,
index_t
>
{
Sequence
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
>
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
};
Sequence
<
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
>
,
GemmSpec
,
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
ASpec
,
BSpec
,
B1Spec
,
CSpec
>
;
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
{
Number
<
AK1
>
{});
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
}
}();
const
auto
a_grid_desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
AK1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
}
}();
const
auto
b_grid_desc_n_k
=
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
BK1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b1_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
B1Layout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
}
}();
const
auto
b1_grid_desc_n_k
=
matrix_padder
.
PadB1Descriptor_N_K
(
b1_grid_desc_nraw_kraw
);
const
auto
N
=
b1_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b1_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
B1K0
=
K
/
B1K1
;
return
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
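The removed descriptor builders above end with an unmerge transform that splits the padded K extent into (K0, K1), with K1 being the compile-time factor AK1/BK1/B1K1. A minimal standalone sketch of that index arithmetic (illustrative only):

    // Standalone sketch: an element (m, k) of the padded [M, K] view is addressed as
    // (k0, m, k1) with k = k0 * K1 + k1.
    #include <cassert>
    #include <cstdint>

    struct SplitK { std::int64_t k0; std::int64_t k1; };

    SplitK split_k(std::int64_t k, std::int64_t K1)
    {
        assert(K1 > 0);
        return {k / K1, k % K1}; // k == k0 * K1 + k1
    }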
-    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-    static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
-                                        const std::vector<index_t>& c_gs_ms_ns_strides_vec)
+    static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
+                                              const std::vector<index_t>& b_gs_ns_ks_strides_vec)
     {
-        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
-        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
-        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
-
-        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
-               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
-
-        const auto to_tuple = [&](auto& vec, auto start, auto end) {
-            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
-        };
-
-        const auto c_ms_ns_lengths = to_tuple(
-            c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
-        const auto c_ms_ns_strides = to_tuple(
-            c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
-
-        // dimension Ids for M0, M1, ...
-        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
-
-        // dimension Ids for N0, N1, ...
-        constexpr auto nDimIds =
-            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
-
-        // lengths for M0, M1, ...
-        const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
-
-        // lengths for K0, K1, ...
-        const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
-
-        // naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
-        const auto c_grid_desc_ms_ns =
-            make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
-
-        // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
-        const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
-            c_grid_desc_ms_ns,
-            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
-            make_tuple(mDimIds, nDimIds),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
+        return Transform::MakeB0GridDescriptor_BK0_N_BK1(
+            Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
+            Number<BK1>{});
     }

-    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
-                                          const std::vector<index_t>& c_gs_ms_ns_strides_vec)
+    static auto
+    MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
+                                   const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
     {
-        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
-        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
-        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
-
-        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
-               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
-
-        const auto to_tuple = [&](auto& vec, auto start, auto end) {
-            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
-        };
-
-        const auto c_gs_ms_ns_lengths =
-            to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
-        const auto c_gs_ms_ns_strides =
-            to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
-
-        // dimension Ids for G0, G1, ...
-        constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
-
-        // dimension Ids for M0, M1, ...
-        constexpr auto mDimIds =
-            typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
-
-        // dimension Ids for N0, N1, ...
-        constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
-                                                                  NumDimG + NumDimM + NumDimN,
-                                                                  1>::type{};
-
-        // lengths for G0, G1, ...
-        const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
-
-        // lengths for M0, M1, ...
-        const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
-
-        // lengths for K0, K1, ...
-        const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
-
-        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
-        const auto c_grid_desc_gs_ms_ns =
-            make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
-
-        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
-        // N2 * ...]
-        const auto c_grid_desc_g_mraw_nraw = transform_tensor_descriptor(
-            c_grid_desc_gs_ms_ns,
-            make_tuple(make_merge_transform(gLengths),
-                       make_merge_transform(mLengths),
-                       make_merge_transform(nLengths)),
-            make_tuple(gDimIds, mDimIds, nDimIds),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-        // this desc is only for calculating batch offset so no padding needed
-        return c_grid_desc_g_mraw_nraw;
+        return Transform::MakeB1GridDescriptor_BK0_N_BK1(
+            Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
+                                                b1_gs_gemm1ns_gemm1ks_strides_vec),
+            Number<B1K1>{});
     }

-    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
-    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
-    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
-    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N({}, {}));
-    using CGridDesc_G_M_N      = decltype(MakeCGridDescriptor_G_M_N({}, {}));
+    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
+    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
+    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
+    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
+    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
+    using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
+    using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
+    using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-    // to track the points which need to be set to -inf on C0
-    // Note: no need to reset M padding value, because they will not be stored out.
-    struct C0MatrixMask
-    {
-        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
-
-        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
-
-        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
-        {
-            return n >= NRaw_;
-        }
-
-        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
-        {
-            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
-        }
-
-        private:
-        // index_t MRaw_;
-        index_t NRaw_;
-    };
+    constexpr static auto make_MaskOutPredicate()
+    {
+        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
+        {
+            return MaskDisabledPredicate{};
+        }
+        else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
+        {
+            return MaskOutUpperTrianglePredicate{};
+        }
+    }
+    using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
     struct ComputeBasePtrOfStridedBatch
     {
-        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
-                                     index_t BatchStrideB,
-                                     index_t BatchStrideB1,
-                                     CGridDesc_G_M_N c_grid_desc_g_m_n)
-            : BatchStrideA_(BatchStrideA),
-              BatchStrideB_(BatchStrideB),
-              BatchStrideB1_(BatchStrideB1),
+        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
+                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
+                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
+                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
+            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
+              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
+              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
              c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
         {
         }

         __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideA_);
+            return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideB_);
+            return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideB1_);
+            return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

         __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
...
@@ -457,9 +317,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         }

         private:
-        index_t BatchStrideA_;
-        index_t BatchStrideB_;
-        index_t BatchStrideB1_;
+        AGridDesc_G_M_K a_grid_desc_g_m_k_;
+        BGridDesc_G_N_K b_grid_desc_g_n_k_;
+        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
     };
...
@@ -523,47 +383,59 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
-        matrix_padder.PadN,
-        MaskOutUpperTriangle>;
+        Transform::matrix_padder.PadN,
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;

     // Argument
     // FIXME: constness
     struct Argument : public BaseArgument
     {
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
                  const B1DataType* p_b1_grid,
                  CDataType* p_c_grid,
-                 index_t MRaw,
-                 index_t NRaw,
-                 index_t KRaw,
-                 index_t Gemm1NRaw, // = ORaw
-                 index_t Batch,
-                 std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
-                 std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
-                 index_t StrideA,
-                 index_t StrideB,
-                 index_t StrideB1,
-                 index_t BatchStrideA,
-                 index_t BatchStrideB,
-                 index_t BatchStrideB1,
+                 const std::array<void*, NumAcc0Bias> p_acc0_biases,
+                 const std::array<void*, NumAcc1Bias> p_acc1_biases,
+                 const std::vector<index_t>& a_gs_ms_ks_lengths,
+                 const std::vector<index_t>& a_gs_ms_ks_strides,
+                 const std::vector<index_t>& b_gs_ns_ks_lengths,
+                 const std::vector<index_t>& b_gs_ns_ks_strides,
+                 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
+                 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
+                 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
+                 const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+                 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
+                 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
+                 const std::array<std::vector<ck::index_t>, NumAcc1Bias>
+                     acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
+                 const std::array<std::vector<ck::index_t>, NumAcc1Bias>
+                     acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  AccElementwiseOperation acc_element_op,
                  B1ElementwiseOperation b1_element_op,
                  CElementwiseOperation c_element_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_b1_grid_{p_b1_grid},
               p_c_grid_{p_c_grid},
-              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
-              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
-              b1_grid_desc_bk0_n_bk1_{
-                  DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)},
-              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
-                                                                 c_gs_ms_gemm1ns_strides)},
-              c_grid_desc_g_m_n_{DeviceOp::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
-                                                                     c_gs_ms_gemm1ns_strides)},
+              a_grid_desc_ak0_m_ak1_{
+                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
+              b_grid_desc_bk0_n_bk1_{
+                  DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
+              b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
+                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
+              c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
+                                                                  c_gs_ms_gemm1ns_strides)},
+              a_grid_desc_g_m_k_{
+                  Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
+              b_grid_desc_g_n_k_{
+                  Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
+              b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
+                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
+              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
+                                                                      c_gs_ms_gemm1ns_strides)},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
               block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
               a_element_op_{a_element_op},
...
@@ -571,14 +443,31 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
               acc_element_op_{acc_element_op},
               b1_element_op_{b1_element_op},
               c_element_op_{c_element_op},
-              batch_count_(Batch),
-              compute_base_ptr_of_batch_{
-                  BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
-              c0_matrix_mask_{NRaw},
-              raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
-              c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
-              c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
+              c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
+              raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
+                                            b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
+                                            b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
+                                            b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
+              a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
+                               a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
+              b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
+                               b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
+              b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
+                                b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
+              c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
+                                    c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
+              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
+              compute_base_ptr_of_batch_{
+                  a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
         {
+            // TODO ANT: implement bias addition
+            ignore = p_acc0_biases;
+            ignore = p_acc1_biases;
+            ignore = acc0_biases_gs_ms_ns_lengths;
+            ignore = acc0_biases_gs_ms_ns_strides;
+            ignore = acc1_biases_gs_ms_gemm1ns_lengths;
+            ignore = acc1_biases_gs_ms_gemm1ns_strides;
+
             if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                            b_grid_desc_bk0_n_bk1_,
                                            b1_grid_desc_bk0_n_bk1_,
@@ -591,34 +480,66 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -591,34 +480,66 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
}
}
}
// private:
void
Print
()
const
{
std
::
cout
<<
"a_grid_desc_g_m_k_: "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I0
)
<<
", "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I1
)
<<
", "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I2
)
<<
'\n'
;
// a_grid_desc_g_m_k_.Print();
std
::
cout
<<
"b_grid_desc_g_n_k_: "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1_grid_desc_g_n_k_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1_grid_desc_g_n_k_.Print();
std
::
cout
<<
"c_grid_desc_g_m_n_: "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
// c_grid_desc_g_m_n_.Print();
}
// pointers
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
const
B1DataType
*
p_b1_grid_
;
const
B1DataType
*
p_b1_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
AccElementwiseOperation
acc_element_op_
;
AccElementwiseOperation
acc_element_op_
;
B1ElementwiseOperation
b1_element_op_
;
B1ElementwiseOperation
b1_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
// check C0 masking and padding
// check C0 masking and padding
C0MatrixMask
c0_matrix_mask_
;
C0MatrixMask
c0_matrix_mask_
;
// For robust IsSupportedArgument() check
// For robust IsSupportedArgument() check
std
::
vector
<
index_t
>
raw_lengths_m_n_k_o_
;
std
::
vector
<
index_t
>
raw_lengths_mz_nz_kz_gemm1nz_
;
index_t
c_extent_lowest_
;
std
::
vector
<
index_t
>
a_mz_kz_strides_
;
index_t
c_stride_lowest_
;
std
::
vector
<
index_t
>
b_nz_kz_strides_
;
std
::
vector
<
index_t
>
b1_nz_kz_strides_
;
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
};
};
// Invoker
// Invoker
...
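The raw_lengths_mz_nz_kz_gemm1nz_ and *_strides_ members above are filled from the last ("z") dimension of each dimension group in the user-supplied length/stride vectors. A minimal standalone sketch of that index selection (the vector layout assumed here is [G dims..., X dims..., Y dims...], as used in this file):

    // Standalone sketch: the "z" entry of a group is its last (fastest-varying) dimension,
    // found at offset (start of group + group size - 1), e.g. a_gs_ms_ks[NumDimG + NumDimM - 1].
    #include <cstdint>
    #include <vector>

    std::int64_t last_dim_of_group(const std::vector<std::int64_t>& gs_xs_ys,
                                   int num_dim_g,
                                   int num_dim_x)
    {
        return gs_xs_ys[num_dim_g + num_dim_x - 1];
    }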
 // Invoker
...
@@ -628,13 +549,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
-                                            arg.b_grid_desc_bk0_n_bk1_,
-                                            arg.b1_grid_desc_bk0_n_bk1_,
-                                            arg.c_grid_desc_m_n_,
-                                            arg.block_2_ctile_map_))
+            if(!DeviceOp::IsSupportedArgument(arg))
             {
-                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
+                throw std::runtime_error("wrong! unsupported argument");
             }

             const index_t grid_size =
...
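With this change, Invoker::Run re-validates the full argument via IsSupportedArgument instead of only the gridwise-gemm check. A minimal, hedged usage sketch of the host-side flow this implies (argument construction elided; see the Argument constructor in this file, and treat the return-value convention as an assumption):

    // Standalone sketch: validate before launching; Run() throws on unsupported arguments.
    template <typename DeviceOpT, typename ArgumentT>
    float run_if_supported(const ArgumentT& argument)
    {
        if(!DeviceOpT::IsSupportedArgument(argument)) // same check Run() now performs
        {
            return -1.f; // caller-side handling instead of the runtime_error thrown by Run()
        }
        auto invoker = DeviceOpT::MakeInvoker();
        return invoker.Run(argument); // default StreamConfig, as in the signature above
    }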
@@ -719,17 +636,24 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
+#if 0
+        arg.Print();
+#endif
+
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
         {
             return false;
         }

+        // TODO ANT: Check if tensor specialization & strides mismatch
+
         // Check if C permute dimension matches GEMM + GEMM shape
         const index_t c_g       = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
         const index_t c_m       = arg.c_grid_desc_m_n_.GetLength(I0);
         const index_t c_gemm1n  = arg.c_grid_desc_m_n_.GetLength(I1);
         const index_t a_m       = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
         const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);

         if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
         {
             return false;
...
@@ -737,19 +661,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         // Note: we need raw lengths since threadwise copy can not handle vector load when part of
         // vector is out of bounds
-        const auto MRaw      = arg.raw_lengths_m_n_k_o_[0];
-        const auto NRaw      = arg.raw_lengths_m_n_k_o_[1];
-        const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
-        const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
+        // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
+        const auto MzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
+        const auto NzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
+        const auto KzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
+        const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];

         // Check scalar per vector requirement
-        const auto a_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-        const auto b_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-        const auto b1_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-        const auto c_extent_lowest = arg.c_extent_lowest_;
+        const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
+        const auto b_extent_lowest  = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
+        const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
+        const auto c_extent_lowest  = Gemm1NzRaw;

         if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
              b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
...
@@ -759,8 +681,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             return false;
         }

-        // Check vector store requirement; assumes last dimension in N to be contiguous
-        if(arg.c_stride_lowest_ != 1)
+        // Check vector load/store requirement
+        const auto a_stride_lowest =
+            ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
+        const auto b_stride_lowest =
+            BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
+        const auto b1_stride_lowest =
+            B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
+        const auto c_stride_lowest =
+            arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
+
+        if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
+             c_stride_lowest == 1))
         {
             return false;
         }
...
@@ -778,46 +710,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
static
auto
MakeArgument
(
const
BDataType
*
p_b
,
const
ADataType
*
p_a
,
const
B1DataType
*
p_b1
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
const
B1DataType
*
p_b1
,
index_t
MRaw
,
CDataType
*
p_c
,
index_t
NRaw
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
index_t
KRaw
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
index_t
Gemm1NRaw
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
index_t
Batch
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
std
::
vector
<
index_t
>
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
std
::
vector
<
index_t
>
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
index_t
StrideA
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
index_t
StrideB
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
index_t
StrideB1
,
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
index_t
BatchStrideA
,
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
index_t
BatchStrideB
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
index_t
BatchStrideB1
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
BElementwiseOperation
b_element_op
,
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
AccElementwiseOperation
acc_element_op
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
B1ElementwiseOperation
b1_element_op
,
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
CElementwiseOperation
c_element_op
)
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_b
,
p_b1
,
p_b1
,
p_c
,
p_c
,
MRaw
,
p_acc0_biases
,
NRaw
,
p_acc1_biases
,
KRaw
,
a_gs_ms_ks_lengths
,
Gemm1NRaw
,
a_gs_ms_ks_strides
,
Batch
,
b_gs_ns_ks_lengths
,
c_gs_ms_gemm1ns_lengths
,
b_gs_ns_ks_strides
,
c_gs_ms_gemm1ns_strides
,
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
StrideA
,
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
StrideB
,
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
StrideB1
,
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
BatchStrideA
,
acc0_biases_gs_ms_ns_lengths
,
BatchStrideB
,
acc0_biases_gs_ms_ns_strides
,
BatchStrideB1
,
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
acc_element_op
,
...
@@ -829,47 +766,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -829,47 +766,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// polymorphic
// polymorphic
// FIXME: constness
// FIXME: constness
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
const
void
*
p_b1
,
const
void
*
p_b1
,
void
*
p_c
,
void
*
p_c
,
index_t
MRaw
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
index_t
NRaw
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
index_t
KRaw
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
index_t
Gemm1NRaw
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
index_t
Batch
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
std
::
vector
<
index_t
>
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
std
::
vector
<
index_t
>
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
index_t
StrideA
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
index_t
StrideB
,
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
index_t
StrideB1
,
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
index_t
BatchStrideA
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
index_t
BatchStrideB
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
index_t
BatchStrideB1
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
AElementwiseOperation
a_element_op
,
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
BElementwiseOperation
b_element_op
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
AccElementwiseOperation
acc_element_op
,
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
B1ElementwiseOperation
b1_element_op
,
AElementwiseOperation
a_element_op
,
CElementwiseOperation
c_element_op
)
override
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
B1DataType
*>
(
p_b1
),
static_cast
<
const
B1DataType
*>
(
p_b1
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
CDataType
*>
(
p_c
),
MRaw
,
p_acc0_biases
,
// cast in struct Argument
NRaw
,
p_acc1_biases
,
// cast in struct Argument
KRaw
,
a_gs_ms_ks_lengths
,
Gemm1NRaw
,
a_gs_ms_ks_strides
,
Batch
,
b_gs_ns_ks_lengths
,
c_gs_ms_gemm1ns_lengths
,
b_gs_ns_ks_strides
,
c_gs_ms_gemm1ns_strides
,
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
StrideA
,
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
StrideB
,
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
StrideB1
,
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
BatchStrideA
,
acc0_biases_gs_ms_ns_lengths
,
BatchStrideB
,
acc0_biases_gs_ms_ns_strides
,
BatchStrideB1
,
acc1_biases_gs_ms_gemm1ns_lengths
,
acc1_biases_gs_ms_gemm1ns_strides
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
acc_element_op
,
...
@@ -901,7 +842,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -901,7 +842,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
B1K1
<<
", "
<<
B1K1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
;
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
"ASpec"
<<
getTensorSpecializationString
(
ASpec
)
<<
", "
<<
"B0Spec"
<<
getTensorSpecializationString
(
BSpec
)
<<
", "
<<
"B1Spec"
<<
getTensorSpecializationString
(
B1Spec
)
<<
", "
<<
"CSpec"
<<
getTensorSpecializationString
(
CSpec
)
<<
", "
<<
getMaskingSpecializationString
(
MaskingSpec
)
<<
">"
;
// clang-format on
// clang-format on
return
str
.
str
();
return
str
.
str
();
...
...
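The new IsSupportedArgument logic above boils down to one rule per operand: the dimension chosen for vectorized transfer must have stride 1 and a raw extent divisible by the scalar-per-vector width. A minimal standalone sketch of that rule follows; the helper name can_vectorize and the two-dimensional view are illustrative only, not part of the library.

#include <array>
#include <cstdint>

using index_t = int32_t;

// Returns true if the dimension `vector_dim` (0 or 1) of a 2-D tile with the given raw
// extents and strides can be accessed with `scalar_per_vector`-wide vector loads/stores.
bool can_vectorize(const std::array<index_t, 2>& raw_extents, // e.g. {Mz, Kz}
                   const std::array<index_t, 2>& strides,     // strides of those two dims
                   index_t vector_dim,
                   index_t scalar_per_vector)
{
    return strides[vector_dim] == 1 && raw_extents[vector_dim] % scalar_per_vector == 0;
}

// Example: a row-major A tile [Mz, Kz] = {120, 64} with strides {64, 1} supports an
// 8-wide load along K (stride 1, 64 % 8 == 0) but not along M (stride 64 != 1).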
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp  View file @ 4fec5ad3
...
@@ -12,6 +12,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
 #include "ck/host_utility/device_prop.hpp"
...
@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                 BElementwiseOperation,
                                                 AccElementwiseOperation,
                                                 B1ElementwiseOperation,
-                                                CElementwiseOperation>
+                                                CElementwiseOperation,
+                                                MaskOutUpperTriangle>
 {
     using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;
...
@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
     }

-    // to track the points which need to be set to -inf on C0
-    // Note: no need to reset M padding value, because they will not be stored out.
-    struct C0MatrixMask
-    {
-        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
-
-        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
-
-        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
-        {
-            return n >= NRaw_;
-        }
-
-        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
-        {
-            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
-        }
-
-        private:
-        // index_t MRaw_;
-        index_t NRaw_;
-    };
-
     struct ComputeBasePtrOfStridedBatch
     {
         ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
...
@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
     using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
     using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

+    using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
+                                       C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
+                                       C0MatrixMask_impl<MaskDisabledPredicate>>;
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
         ADataType, // TODO: distinguish A/B datatype
...
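The hunk above replaces the hard-coded C0 mask struct with a compile-time choice between a causal (upper-triangle) predicate and a disabled predicate, driven by the new MaskOutUpperTriangle parameter. A self-contained sketch of that selection pattern follows; the predicate names here are illustrative, the real ones come from masking_specialization.hpp.

#include <type_traits>
#include <cstdint>

using index_t = int32_t;

// Return true when element (m, n) of the attention score matrix must be masked out (-inf).
struct MaskUpperTrianglePredicate
{
    bool operator()(index_t m, index_t n) const { return n > m; } // causal mask
};

struct NoMaskPredicate
{
    bool operator()(index_t, index_t) const { return false; }
};

template <bool MaskOutUpperTriangle>
using MaskPredicate =
    std::conditional_t<MaskOutUpperTriangle, MaskUpperTrianglePredicate, NoMaskPredicate>;

// Usage: MaskPredicate<true>{}(2, 5) is true, so the softmax ignores "future" positions,
// while MaskPredicate<false>{} keeps every score and the mask compiles away entirely.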
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp  0 → 100644  View file @ 4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          index_t Rank,
          index_t NumBatchNormReduceDim,
          bool UseMultiblockInK,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcYDstVectorDim,
          index_t XSrcVectorSize,
          index_t YDstVectorSize,
          index_t ScaleSrcVectorSize,
          index_t BiasSrcVectorSize,
          index_t MeanVarSrcDstVectorSize>
struct DeviceBatchNormFwdImpl
    : public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");

    static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
                                   const std::array<index_t, Rank>& xyStrides,
                                   int blkGroupSize,
                                   int numBlockTileIteration)
    {
        const auto tupleXYLengths =
            generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
        const auto tupleXYStrides =
            generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});

        const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);

        const auto grid_desc_m_k = [&]() {
            using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
            using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

            const auto reduceDimLengths =
                generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
                               Number<NumBatchNormReduceDim>{});
            const auto invariantDimLengths =
                generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});

            return transform_tensor_descriptor(
                raw_grid_desc,
                make_tuple(make_merge_transform(invariantDimLengths),
                           make_merge_transform(reduceDimLengths)),
                make_tuple(InvariantDims{}, ReduceDims{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }();

        const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = grid_desc_m_k.GetLength(Number<1>{});

        const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;

        auto grid_desc_m_k_padded =
            transform_tensor_descriptor(grid_desc_m_k,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_right_pad_transform(reduceLength, kPad)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_k_padded);
    };

    static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
    {
        const auto grid_desc_m_g =
            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto grid_desc_m_g_padded =
            transform_tensor_descriptor(grid_desc_m_g,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_pass_through_transform(blkGroupSize)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_g_padded);
    };

    static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
    {
        const auto reduceLength = blkGroupSize;
        const auto grid_desc_m_k =
            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto kPad =
            math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;

        auto grid_desc_m_k_padded =
            transform_tensor_descriptor(grid_desc_m_k,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_right_pad_transform(reduceLength, kPad)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_k_padded);
    };

    static auto
    MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
                                     const std::array<index_t, NumInvariantDim>& strides)
    {
        const auto tupleLengths =
            generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
        const auto tupleStrides =
            generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});

        auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);

        auto grid_desc_m = transform_tensor_descriptor(
            raw_grid_desc,
            make_tuple(make_merge_transform(tupleLengths)),
            make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto grid_desc_m_padded = transform_tensor_descriptor(
            grid_desc_m,
            make_tuple(make_right_pad_transform(invariantLength, mPad)),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0>{}));

        return (grid_desc_m_padded);
    };

    using XYGridDesc_M_K             = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
    using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));

    struct Argument : public BaseArgument
    {
        Argument(const std::array<index_t, Rank> xyLengths,
                 const std::array<index_t, Rank> xStrides,
                 const std::array<index_t, Rank> yStrides,
                 const std::array<int, NumBatchNormReduceDim> reduceDims,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const ScaleDataType* p_scale,
                 const BiasDataType* p_bias,
                 const YElementwiseOp y_elementwise_op,
                 double epsilon,
                 YDataType* p_y,
                 MeanVarDataType* resultSaveMean,
                 MeanVarDataType* resultSaveInvVariance,
                 double averageFactor,
                 MeanVarDataType* resultRunningMean,
                 MeanVarDataType* resultRunningVariance)
            : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
              bnScaleStrides_(bnScaleStrides),
              bnBiasStrides_(bnBiasStrides),
              bnMeanVarStrides_(bnMeanVarStrides),
              p_x_(p_x),
              p_scale_(p_scale),
              p_bias_(p_bias),
              y_elementwise_op_(y_elementwise_op),
              p_y_(p_y),
              resultSaveMean_(resultSaveMean),
              resultSaveInvVariance_(resultSaveInvVariance),
              resultRunningMean_(resultRunningMean),
              resultRunningVariance_(resultRunningVariance)
        {
            xyLengths_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
            xStrides_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
            yStrides_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(yStrides, reduceDims);

            std::tie(invariant_length_, reduce_length_) =
                get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);

            epsilon_       = type_convert<AccDataType>(epsilon);
            averageFactor_ = type_convert<AccDataType>(averageFactor);

            updateMovingAverage_ =
                (resultRunningMean != nullptr && resultRunningVariance != nullptr);
            saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr);

            if(UseMultiblockInK)
            {
                int iterations = 1;
                while(true)
                {
                    int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
                                           (K_BlockTileSize * iterations);

                    // we want the blkGroupSize be not more than 128
                    if(testBlkGroupSize <= 128)
                        break;

                    iterations++;
                };

                blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
                                (K_BlockTileSize * iterations);

                numBlockTileIteration_ = iterations;
            }
            else
            {
                blkGroupSize_          = 1;
                numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
            };

            gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_;

            x_grid_desc_m_k_ =
                MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
            y_grid_desc_m_k_ =
                MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
            scale_grid_desc_m_ =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
            bias_grid_desc_m_ =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
            mean_var_grid_desc_m_ =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
        }

        AccDataType epsilon_;
        AccDataType averageFactor_;

        bool updateMovingAverage_;
        bool saveMeanInvVariance_;

        std::array<index_t, Rank> xyLengths_;
        std::array<index_t, Rank> xStrides_;
        std::array<index_t, Rank> yStrides_;

        std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;

        const XDataType* p_x_;
        const ScaleDataType* p_scale_;
        const BiasDataType* p_bias_;
        const YElementwiseOp y_elementwise_op_;
        YDataType* p_y_;

        MeanVarDataType* resultSaveMean_;
        MeanVarDataType* resultSaveInvVariance_;
        MeanVarDataType* resultRunningMean_;
        MeanVarDataType* resultRunningVariance_;

        long_index_t invariant_length_;
        long_index_t reduce_length_;

        int blkGroupSize_;
        int numBlockTileIteration_;
        size_t gridSize_;

        XYGridDesc_M_K x_grid_desc_m_k_;
        XYGridDesc_M_K y_grid_desc_m_k_;
        ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_;
        ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_;
        ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_;

        void* workspace_mean_;
        void* workspace_variance_;
        void* workspace_count_;
    };

    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        size_t workspace_size = 0;

        if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
        {
            // workspace for welford intermediate mean
            workspace_size +=
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;

            // workspace for welford intermediate variance
            workspace_size +=
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;

            // workspace for welford intermediate count
            workspace_size +=
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
        }

        return (workspace_size);
    };

    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);

        pArg_->p_workspace_ = p_workspace;

        if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
        {
            // setup buffer used for intermediate welford mean
            pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);

            index_t mean_space_sz =
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);

            mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);

            // setup buffer used for intermediate welford varirance
            pArg_->workspace_variance_ =
                reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;

            index_t variance_space_sz =
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);

            variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);

            // setup buffer used for intermediate welfor count
            pArg_->workspace_count_ =
                reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
        };
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            float avg_time = 0;

            if(UseMultiblockInK && arg.blkGroupSize_ > 1)
            {
                using GetReduceCountPerThreadFunctor =
                    GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;

                GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
                    arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);

                const auto mean_var_count_grid_desc_m_g =
                    DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor(
                        arg.invariant_length_, arg.blkGroupSize_);

                const auto mean_var_count_grid_desc_m_k =
                    DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor(
                        arg.invariant_length_, arg.blkGroupSize_);

                using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
                using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);

                using GridwiseMultiblockWelfordFirstHalf_ = GridwiseMultiblockWelfordFirstHalf<
                    XDataType, AccDataType, MeanVarDataType,
                    XYGridDesc_M_K, MeanVarCountGridDesc_M_G, GetReduceCountPerThreadFunctor,
                    BlockSize, MThreadClusterSize, KThreadClusterSize,
                    MThreadSliceSize, KThreadSliceSize,
                    XSrcYDstVectorDim, XSrcVectorSize>;

                using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
                    GridwiseWelfordSecondHalfBatchNormForwardFinal<
                        XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType,
                        MeanVarDataType, YElementwiseOp,
                        XYGridDesc_M_K, MeanVarCountGridDesc_M_K,
                        ScaleBiasMeanVarGridDesc_M, ScaleBiasMeanVarGridDesc_M,
                        BlockSize, MThreadClusterSize, KThreadClusterSize,
                        MThreadSliceSize, KThreadSliceSize,
                        XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize,
                        ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize>;

                index_t numMeanVarCountBlockTileIteration =
                    (arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;

                const auto kern_multiblock_welford_first_half =
                    kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
                                                         XDataType,
                                                         MeanVarDataType,
                                                         XYGridDesc_M_K,
                                                         MeanVarCountGridDesc_M_G,
                                                         GetReduceCountPerThreadFunctor>;

                const auto kern_welford_second_half_batchnorm_forward_final =
                    kernel_welford_second_half_batchnorm_forward_final<
                        GridwiseWelfordSecondHalfBatchNormForwardFinal_,
                        XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType,
                        MeanVarDataType, YElementwiseOp,
                        XYGridDesc_M_K, MeanVarCountGridDesc_M_K,
                        ScaleBiasMeanVarGridDesc_M, ScaleBiasMeanVarGridDesc_M>;

                avg_time += launch_and_time_kernel(
                    stream_config,
                    kern_multiblock_welford_first_half,
                    dim3(arg.gridSize_),
                    dim3(BlockSize),
                    0,
                    arg.x_grid_desc_m_k_,
                    mean_var_count_grid_desc_m_g,
                    get_reduce_count_per_thread,
                    arg.numBlockTileIteration_,
                    arg.p_x_,
                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
                    static_cast<int32_t*>(arg.workspace_count_));

                avg_time += launch_and_time_kernel(
                    stream_config,
                    kern_welford_second_half_batchnorm_forward_final,
                    dim3(arg.gridSize_),
                    dim3(BlockSize),
                    0,
                    arg.x_grid_desc_m_k_,
                    arg.y_grid_desc_m_k_,
                    mean_var_count_grid_desc_m_k,
                    arg.scale_grid_desc_m_,
                    arg.bias_grid_desc_m_,
                    arg.mean_var_grid_desc_m_,
                    arg.blkGroupSize_,
                    arg.numBlockTileIteration_,
                    numMeanVarCountBlockTileIteration,
                    arg.epsilon_,
                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
                    static_cast<int32_t*>(arg.workspace_count_),
                    arg.p_x_,
                    arg.p_scale_,
                    arg.p_bias_,
                    arg.y_elementwise_op_,
                    arg.p_y_,
                    arg.updateMovingAverage_,
                    arg.averageFactor_,
                    arg.resultRunningMean_,
                    arg.resultRunningVariance_,
                    arg.saveMeanInvVariance_,
                    arg.resultSaveMean_,
                    arg.resultSaveInvVariance_);
            }
            else
            {
                using GetReduceCountPerThreadFunctor =
                    GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;

                GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
                    arg.numBlockTileIteration_, arg.reduce_length_);

                using GridwiseBatchNormForwardWithBlockwiseWelford_ =
                    GridwiseBatchNormForwardWithBlockwiseWelford<
                        XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType,
                        MeanVarDataType, YElementwiseOp,
                        XYGridDesc_M_K, ScaleBiasMeanVarGridDesc_M, ScaleBiasMeanVarGridDesc_M,
                        GetReduceCountPerThreadFunctor,
                        BlockSize, MThreadClusterSize, KThreadClusterSize,
                        MThreadSliceSize, KThreadSliceSize,
                        XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize,
                        ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize>;

                const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
                    GridwiseBatchNormForwardWithBlockwiseWelford_,
                    XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType,
                    MeanVarDataType, YElementwiseOp,
                    XYGridDesc_M_K, ScaleBiasMeanVarGridDesc_M, ScaleBiasMeanVarGridDesc_M,
                    GetReduceCountPerThreadFunctor>;

                avg_time += launch_and_time_kernel(stream_config,
                                                   kern_batchnorm_fwd,
                                                   dim3(arg.gridSize_),
                                                   dim3(BlockSize),
                                                   0,
                                                   arg.x_grid_desc_m_k_,
                                                   arg.y_grid_desc_m_k_,
                                                   arg.scale_grid_desc_m_,
                                                   arg.bias_grid_desc_m_,
                                                   arg.mean_var_grid_desc_m_,
                                                   get_reduce_count_per_thread,
                                                   arg.numBlockTileIteration_,
                                                   arg.epsilon_,
                                                   arg.p_x_,
                                                   arg.p_scale_,
                                                   arg.p_bias_,
                                                   arg.y_elementwise_op_,
                                                   arg.p_y_,
                                                   arg.updateMovingAverage_, // true or false
                                                   arg.averageFactor_,
                                                   arg.resultRunningMean_,
                                                   arg.resultRunningVariance_,
                                                   arg.saveMeanInvVariance_, // true or false
                                                   arg.resultSaveMean_,
                                                   arg.resultSaveInvVariance_);
            };

            return (avg_time);
        };

        float Run(const BaseArgument* pArg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
        };
    };

    bool IsSupportedArgument(const BaseArgument* pArg) override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        if constexpr(XSrcYDstVectorDim == 0)
        {
            if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
               pArg_->yStrides_[NumInvariantDim - 1] != 1)
                return false;

            if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
               pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
                return false;
        }
        else
        {
            if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
                return false;

            if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
               pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
                return false;
        };

        if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
            return false;
        if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
            return false;

        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
            return false;
        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
            return false;

        if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
            return false;

        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
            return false;

        bool is_valid = true;

        static_for<0, NumInvariantDim, 1>{}([&](auto I) {
            if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
                is_valid = false;
        });

        if(!is_valid)
            return false;

        return true;
    };

    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, Rank> xyLengths,
        const std::array<index_t, Rank> xStrides,
        const std::array<index_t, Rank> yStrides,
        const std::array<int, NumBatchNormReduceDim> reduceDims,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
        const void* p_x,
        const void* p_scale,
        const void* p_bias,
        double epsilon,
        const YElementwiseOp y_elementwise_op,
        void* p_y,
        void* resultSaveMean,
        void* resultSaveInvVariance,
        double averageFactor,
        void* resultRunningMean,
        void* resultRunningVariance) override
    {
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          yStrides,
                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
                                          bnScaleStrides,
                                          bnBiasStrides,
                                          bnMeanVarStrides,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const ScaleDataType*>(p_scale),
                                          static_cast<const BiasDataType*>(p_bias),
                                          y_elementwise_op,
                                          epsilon,
                                          static_cast<YDataType*>(p_y),
                                          static_cast<MeanVarDataType*>(resultSaveMean),
                                          static_cast<MeanVarDataType*>(resultSaveInvVariance),
                                          averageFactor,
                                          static_cast<MeanVarDataType*>(resultRunningMean),
                                          static_cast<MeanVarDataType*>(resultRunningVariance));
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
        str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize
            << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize
            << "_Y" << YDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
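For orientation, here is a host-side sketch of how a DeviceBatchNormFwdImpl instance is typically driven for an NHWC tensor normalized over N, H and W. The concrete template parameter values, the PassThrough element-wise op, the tensor sizes, and the helper name run_batchnorm_forward_nhwc are illustrative assumptions; the member-function sequence (MakeArgumentPointer, GetWorkSpaceSize, SetWorkSpacePointer, IsSupportedArgument, MakeInvokerPointer, Run) follows the interface shown above.

#include <array>
#include <hip/hip_runtime.h>
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Illustrative instantiation: fp16 in/out, fp32 accumulation, rank-4 NHWC reduced over 3 dims,
// multiblock reduction, 256 threads as a 16x16 (M, K) cluster, scalar (size-1) vector accesses.
using DeviceBatchNormFwdInstance = ck::tensor_operation::device::DeviceBatchNormFwdImpl<
    ck::half_t, ck::half_t, float, ck::half_t, ck::half_t, float, PassThrough,
    4, 3, true, 256, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1>;

float run_batchnorm_forward_nhwc(const void* p_x, const void* p_scale, const void* p_bias,
                                 void* p_y, void* p_save_mean, void* p_save_inv_var,
                                 void* p_run_mean, void* p_run_var)
{
    DeviceBatchNormFwdInstance device_op;

    // NHWC: reduce over {N, H, W}; per-channel scale/bias/mean/variance of length C.
    std::array<ck::index_t, 4> lengths   = {16, 8, 8, 128};              // {N, H, W, C}
    std::array<ck::index_t, 4> strides   = {8 * 8 * 128, 8 * 128, 128, 1};
    std::array<int, 3> reduce_dims       = {0, 1, 2};
    std::array<ck::index_t, 1> c_lengths = {128};
    std::array<ck::index_t, 1> c_strides = {1};

    auto argument_ptr = device_op.MakeArgumentPointer(
        lengths, strides, strides, reduce_dims,
        c_lengths, c_strides, c_strides, c_strides,
        p_x, p_scale, p_bias,
        1e-5,          // epsilon
        PassThrough{}, // y element-wise op
        p_y, p_save_mean, p_save_inv_var,
        0.1,           // averageFactor for the running estimates
        p_run_mean, p_run_var);

    // The multiblock path needs scratch space for the partial Welford mean/variance/count.
    void* p_workspace         = nullptr;
    const size_t workspace_sz = device_op.GetWorkSpaceSize(argument_ptr.get());
    if(workspace_sz != 0)
        (void)hipMalloc(&p_workspace, workspace_sz);
    device_op.SetWorkSpacePointer(argument_ptr.get(), p_workspace);

    if(!device_op.IsSupportedArgument(argument_ptr.get()))
        return -1.0f; // configuration rejected by the stride / vector-size checks

    auto invoker_ptr = device_op.MakeInvokerPointer();
    return invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); // time the kernels
}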
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp  View file @ 4fec5ad3
...
@@ -5,9 +5,8 @@
 #include <iostream>
 #include <sstream>
+#include <array>

-#include "ck/utility/common_header.hpp"
-#include "ck/utility/reduction_operator.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
...
@@ -41,7 +40,8 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceMultiBlock
+    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
...
@@ -58,8 +58,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

-    static constexpr index_t numSrcDim = Rank;
-    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr index_t NumSrcDim = Rank;
+    static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
     static constexpr bool reduceAllDim = (NumInvariantDim == 0);

     // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
...
@@ -81,13 +81,15 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

-    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
-                                    const std::vector<index_t>& inStrides,
+    static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
+                                    const std::array<index_t, Rank>& inStrides,
                                     int blkGroupSize,
                                     int numBlockTileIteration)
     {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
+        const auto tupleSrcLengths =
+            generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
+        const auto tupleSrcStrides =
+            generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});

         const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
...
@@ -97,7 +99,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
         const auto one_dim_inDesc = transform_tensor_descriptor(
             inDesc,
             make_tuple(make_merge_transform(tupleSrcLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
             make_tuple(Sequence<0>{}));

         return transform_tensor_descriptor(one_dim_inDesc,
...
@@ -111,10 +113,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
             using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
             using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

-            const auto reduceDimLengths =
-                make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
-            const auto invariantDimLengths =
-                make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
+            const auto reduceDimLengths = generate_tuple(
+                [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
+            const auto invariantDimLengths =
+                generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});

             return transform_tensor_descriptor(
                 inDesc,
...
@@ -143,18 +145,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
         return (in_grid_desc_m_k_padded);
     };

-    static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
-                                    const std::vector<index_t>& outStrides)
+    static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
+                                    const std::array<index_t, NumDstDim>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
+        const auto tupleDstLengths =
+            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
+        const auto tupleDstStrides =
+            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});

         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));

         const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
...
@@ -170,18 +174,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
         return (out_grid_desc_m_padded);
     };

-    static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths,
-                                                const std::vector<index_t>& outStrides)
+    static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumDstDim>& outLengths,
+                                                const std::array<index_t, NumDstDim>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
+        const auto tupleDstLengths =
+            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
+        const auto tupleDstStrides =
+            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});

         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));

         const auto length = out_grid_desc_m.GetLength(Number<0>{});
...
@@ -198,11 +204,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<index_t> inLengths,
-                 const std::vector<index_t> inStrides,
-                 const std::vector<index_t> outLengths,
-                 const std::vector<index_t> outStrides,
-                 const std::vector<int> reduceDims,
+        Argument(const std::array<index_t, Rank> inLengths,
+                 const std::array<index_t, Rank> inStrides,
+                 const std::array<index_t, NumDstDim> outLengths,
+                 const std::array<index_t, NumDstDim> outStrides,
+                 const std::array<int, NumReduceDim> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
...
@@ -272,10 +278,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
                 math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
         }

-        std::vector<index_t> inLengths_;
-        std::vector<index_t> inStrides_;
-        std::vector<index_t> outLengths_;
-        std::vector<index_t> outStrides_;
+        std::array<index_t, Rank> inLengths_;
+        std::array<index_t, Rank> inStrides_;
+        std::array<index_t, NumDstDim> outLengths_;
+        std::array<index_t, NumDstDim> outStrides_;

         AccDataType alpha_;
         AccDataType beta_;
...
@@ -459,11 +465,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
     };

     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<index_t> inLengths,
-                        const std::vector<index_t> inStrides,
-                        const std::vector<index_t> outLengths,
-                        const std::vector<index_t> outStrides,
-                        const std::vector<int> reduceDims,
+    MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
+                        const std::array<index_t, Rank> inStrides,
+                        const std::array<index_t, NumDstDim> outLengths,
+                        const std::array<index_t, NumDstDim> outStrides,
+                        const std::array<int, NumReduceDim> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
...
...
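The recurring change in this file swaps std::vector<index_t> for std::array<index_t, Rank> (and std::array<index_t, NumDstDim> on the output side), which moves the rank check from run time to compile time and lets the descriptor builders iterate with a compile-time bound. A small illustration of the difference follows; the values and names are illustrative only.

#include <array>
#include <vector>
#include <cstdint>

using index_t = int32_t;

constexpr index_t Rank         = 4;
constexpr index_t NumReduceDim = 2;
constexpr index_t NumDstDim    = Rank - NumReduceDim;

// Before: the rank of the length/stride lists is only known at run time, so passing the
// wrong number of extents (e.g. 3 lengths for a rank-4 tensor) is a run-time failure.
std::vector<index_t> in_lengths_dynamic = {16, 32, 8, 8};

// After: the rank is part of the type, so a mismatch does not compile, and builders can
// unroll over a compile-time Number<Rank> (or Number<NumDstDim>) bound.
std::array<index_t, Rank> in_lengths      = {16, 32, 8, 8};
std::array<index_t, NumDstDim> out_lengths = {16, 32}; // reduce dims {2, 3} dropped

int main() { return 0; }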
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp
View file @
4fec5ad3
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include <array>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -34,7 +35,8 @@ template <typename InDataType,
...
@@ -34,7 +35,8 @@ template <typename InDataType,
index_t
InSrcVectorDim
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
index_t
OutDstVectorSize
>
struct
DeviceReduceThreadWise
:
public
DeviceReduce
<
InElementwiseOperation
,
AccElementwiseOperation
>
struct
DeviceReduceThreadWise
:
public
DeviceReduce
<
Rank
,
NumReduceDim
,
InElementwiseOperation
,
AccElementwiseOperation
>
{
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
...
@@ -49,18 +51,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
...
@@ -49,18 +51,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
n
umSrcDim
=
Rank
;
static
constexpr
index_t
N
umSrcDim
=
Rank
;
static
constexpr
index_t
n
umDstDim
=
(
NumInvariantDim
==
0
)
?
1
:
NumInvariantDim
;
static
constexpr
index_t
N
umDstDim
=
(
NumInvariantDim
==
0
)
?
1
:
NumInvariantDim
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
static
constexpr
index_t
M_BlockTileSize
=
BlockSize
*
MThreadSliceSize
;
static
constexpr
index_t
M_BlockTileSize
=
BlockSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
1
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
1
*
KThreadSliceSize
;
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
static
auto
MakeSrc2dDescriptor
(
const
std
::
array
<
index_t
,
Rank
>&
inLengths
,
const
std
::
vector
<
index_t
>&
inStrides
)
const
std
::
array
<
index_t
,
Rank
>&
inStrides
)
{
{
const
auto
tupleSrcLengths
=
make_tuple_from_array
(
inLengths
,
Number
<
numSrcDim
>
{});
const
auto
tupleSrcLengths
=
const
auto
tupleSrcStrides
=
make_tuple_from_array
(
inStrides
,
Number
<
numSrcDim
>
{});
generate_tuple
([
&
](
auto
I
)
{
return
inLengths
[
I
];
},
Number
<
Rank
>
{});
const
auto
tupleSrcStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
inStrides
[
I
];
},
Number
<
Rank
>
{});
const
auto
inDesc
=
make_naive_tensor_descriptor
(
tupleSrcLengths
,
tupleSrcStrides
);
const
auto
inDesc
=
make_naive_tensor_descriptor
(
tupleSrcLengths
,
tupleSrcStrides
);
...
@@ -70,7 +74,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
...
@@ -70,7 +74,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
const
auto
one_dim_inDesc
=
transform_tensor_descriptor
(
const
auto
one_dim_inDesc
=
transform_tensor_descriptor
(
inDesc
,
inDesc
,
make_tuple
(
make_merge_transform
(
tupleSrcLengths
)),
make_tuple
(
make_merge_transform
(
tupleSrcLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
n
umSrcDim
,
1
>::
type
{}),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
N
umSrcDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
0
>
{}));
return
transform_tensor_descriptor
(
one_dim_inDesc
,
return
transform_tensor_descriptor
(
one_dim_inDesc
,
...
@@ -84,10 +88,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
...
@@ -84,10 +88,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
const
auto
reduceDimLengths
=
const
auto
reduceDimLengths
=
generate_tuple
(
make_tuple_from_array_and_index_seq
(
inLengths
,
ReduceDim
s
{});
[
&
](
auto
I
)
{
return
inLengths
[
NumInvariantDim
+
I
];
},
Number
<
Num
ReduceDim
>
{});
const
auto
invariantDimLengths
=
const
auto
invariantDimLengths
=
make_tuple_from_array_and_index_seq
(
inLengths
,
InvariantDim
s
{});
generate_tuple
([
&
](
auto
I
)
{
return
inLengths
[
I
];
},
Number
<
Num
InvariantDim
>
{});
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
inDesc
,
inDesc
,
...
@@ -116,18 +120,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
...
@@ -116,18 +120,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
return
(
in_grid_desc_m_k_padded
);
return
(
in_grid_desc_m_k_padded
);
};
};
static
auto
MakeDst1dDescriptor
(
const
std
::
vector
<
index_t
>&
outLengths
,
static
auto
MakeDst1dDescriptor
(
const
std
::
array
<
index_t
,
NumDstDim
>&
outLengths
,
const
std
::
vector
<
index_t
>&
outStrides
)
const
std
::
array
<
index_t
,
NumDstDim
>&
outStrides
)
{
{
const
auto
tupleDstLengths
=
make_tuple_from_array
(
outLengths
,
Number
<
numDstDim
>
{});
const
auto
tupleDstLengths
=
const
auto
tupleDstStrides
=
make_tuple_from_array
(
outStrides
,
Number
<
numDstDim
>
{});
generate_tuple
([
&
](
auto
I
)
{
return
outLengths
[
I
];
},
Number
<
NumDstDim
>
{});
const
auto
tupleDstStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
outStrides
[
I
];
},
Number
<
NumDstDim
>
{});
auto
outDesc
=
make_naive_tensor_descriptor
(
tupleDstLengths
,
tupleDstStrides
);
auto
outDesc
=
make_naive_tensor_descriptor
(
tupleDstLengths
,
tupleDstStrides
);
auto
out_grid_desc_m
=
transform_tensor_descriptor
(
auto
out_grid_desc_m
=
transform_tensor_descriptor
(
outDesc
,
outDesc
,
make_tuple
(
make_merge_transform
(
tupleDstLengths
)),
make_tuple
(
make_merge_transform
(
tupleDstLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
n
umDstDim
,
1
>::
type
{}),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
N
umDstDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
out_grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
invariantLength
=
out_grid_desc_m
.
GetLength
(
Number
<
0
>
{});
...
@@ -145,11 +151,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<index_t> inLengths,
-                 const std::vector<index_t> inStrides,
-                 const std::vector<index_t> outLengths,
-                 const std::vector<index_t> outStrides,
-                 const std::vector<int> reduceDims,
+        Argument(const std::array<index_t, Rank> inLengths,
+                 const std::array<index_t, Rank> inStrides,
+                 const std::array<index_t, NumDstDim> outLengths,
+                 const std::array<index_t, NumDstDim> outStrides,
+                 const std::array<int, NumReduceDim> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
...
@@ -187,10 +193,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
                        M_BlockTileSize;
         }

-        std::vector<index_t> inLengths_;
-        std::vector<index_t> inStrides_;
-        std::vector<index_t> outLengths_;
-        std::vector<index_t> outStrides_;
+        std::array<index_t, Rank> inLengths_;
+        std::array<index_t, Rank> inStrides_;
+        std::array<index_t, NumDstDim> outLengths_;
+        std::array<index_t, NumDstDim> outStrides_;

         AccDataType alpha_;
         AccDataType beta_;
...
@@ -321,11 +327,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
     };

     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<index_t> inLengths,
-                        const std::vector<index_t> inStrides,
-                        const std::vector<index_t> outLengths,
-                        const std::vector<index_t> outStrides,
-                        const std::vector<int> reduceDims,
+    MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
+                        const std::array<index_t, Rank> inStrides,
+                        const std::array<index_t, NumDstDim> outLengths,
+                        const std::array<index_t, NumDstDim> outStrides,
+                        const std::array<int, NumReduceDim> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
...
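Net effect of this file's changes for callers: lengths, strides, and reduce-dim lists are now fixed-size std::array arguments whose sizes are pinned to the Rank / NumDstDim / NumReduceDim template parameters, so a rank mismatch fails at compile time instead of at argument-checking time. A hypothetical call-site sketch (the shapes and the commented-out MakeArgumentPointer call are placeholders, not taken from this commit; only the array-typed parameters mirror the new interface):

// Hypothetical call-site sketch for the std::vector -> std::array change above.
#include <array>
#include <cstdint>

using index_t = std::int32_t;

int main()
{
    constexpr index_t Rank         = 4; // e.g. an NHWC input
    constexpr index_t NumReduceDim = 3; // reduce over N, H, W
    constexpr index_t NumDstDim    = Rank - NumReduceDim;

    const std::array<index_t, Rank> inLengths{16, 14, 14, 1024};
    const std::array<index_t, Rank> inStrides{14 * 14 * 1024, 14 * 1024, 1024, 1};
    const std::array<index_t, NumDstDim> outLengths{1024};
    const std::array<index_t, NumDstDim> outStrides{1};
    const std::array<int, NumReduceDim> reduceDims{0, 1, 2};

    // auto argument_ptr = device_reduce_instance.MakeArgumentPointer(
    //     inLengths, inStrides, outLengths, outStrides, reduceDims,
    //     alpha, beta, in_dev_buf, out_dev_buf, ...);  // placeholder, see the real signature above

    (void)inLengths; (void)inStrides; (void)outLengths; (void)outStrides; (void)reduceDims;
    return 0;
}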
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp View file @ 4fec5ad3
...
@@ -8,12 +8,9 @@
 #include "ck/utility/reduction_operator.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/device_softmax.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
...
@@ -50,29 +47,80 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
     virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }

-    // Used for freeloading of some handy functions from DeviceReduceMultiBlock
-    using Reduction = DeviceReduceMultiBlock<InDataType,
-                                             AccDataType,
-                                             OutDataType,
-                                             Rank,
-                                             NumReduceDim,
-                                             reduce::Add,
-                                             InElementwiseOp,
-                                             AccElementwiseOp,
-                                             InMemoryDataOperationEnum::Set,
-                                             false, // PropagateNan
-                                             false, // OutputIndex
-                                             false, // HaveIndexInputIfOutputIndex
-                                             BlockSize,
-                                             MThreadClusterSize,
-                                             KThreadClusterSize,
-                                             MThreadSliceSize,
-                                             KThreadSliceSize,
-                                             InSrcVectorDim,
-                                             InSrcVectorSize,
-                                             1>; // OutDstVectorSize
-    using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
+
+    static constexpr index_t NumSrcDim = Rank;
+    static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
+
+    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
+    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+
+    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
+                                    const std::vector<index_t>& inStrides,
+                                    int blkGroupSize,
+                                    int numBlockTileIteration)
+    {
+        const auto tupleSrcLengths =
+            generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
+        const auto tupleSrcStrides =
+            generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
+
+        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto in_grid_desc_m_k = [&]() {
+            if constexpr(reduceAllDim)
+            {
+                const auto one_dim_inDesc = transform_tensor_descriptor(
+                    inDesc,
+                    make_tuple(make_merge_transform(tupleSrcLengths)),
+                    make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
+                    make_tuple(Sequence<0>{}));
+
+                return transform_tensor_descriptor(
+                    one_dim_inDesc,
+                    make_tuple(make_unmerge_transform(
+                        make_tuple(1, one_dim_inDesc.GetLength(Number<0>{})))),
+                    make_tuple(Sequence<0>{}),
+                    make_tuple(Sequence<0, 1>{}));
+            }
+            else
+            {
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
+                const auto reduceDimLengths = generate_tuple(
+                    [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
+                const auto invariantDimLengths =
+                    generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
+
+                return transform_tensor_descriptor(
+                    inDesc,
+                    make_tuple(make_merge_transform(invariantDimLengths),
+                               make_merge_transform(reduceDimLengths)),
+                    make_tuple(InvariantDims{}, ReduceDims{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+            }
+        }();
+
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});
+
+        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
+
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        return (in_grid_desc_m_k_padded);
+    };
+
+    using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));

     using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
                                                             OutDataType,
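The merge transforms introduced above give the gridwise softmax a plain M x K view of the input: invariant dimensions fold into M, reduce dimensions into K, and the result is right-padded to block-tile multiples. A hedged, standalone illustration with made-up NHWC shapes, showing that the merged (m, k) coordinate addresses the same element as the original 4-D index:

// Hedged sketch of the M x K flattening performed by MakeSrc2dDescriptor above.
// Shapes are illustrative; softmax is taken over the last (C) dimension.
#include <array>
#include <cstdint>
#include <iostream>

int main()
{
    const std::int64_t N = 2, H = 3, W = 4, C = 8;                    // lengths
    const std::array<std::int64_t, 4> stride{H * W * C, W * C, C, 1}; // packed NHWC strides

    const std::int64_t n = 1, h = 2, w = 3, c = 5; // an arbitrary element

    // merged coordinates seen by the gridwise softmax kernel
    const std::int64_t m = (n * H + h) * W + w; // invariant dims N, H, W merged into M
    const std::int64_t k = c;                   // reduce dim C becomes K

    const std::int64_t offset_4d = n * stride[0] + h * stride[1] + w * stride[2] + c * stride[3];
    const std::int64_t offset_2d = m * C + k; // M x K view of the same packed buffer

    std::cout << offset_4d << " == " << offset_2d << '\n'; // prints "189 == 189"
}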
...
@@ -102,7 +150,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
                                                             OutDstVectorSize,
                                                             true>;

-    struct Argument : public Reduction::Argument
+    struct Argument : public BaseArgument
     {
         Argument(const std::vector<index_t> inLengths,
                  const std::vector<index_t> inStrides,
...
@@ -113,42 +161,60 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
                  OutDataType* out_dev,
                  InElementwiseOp in_elementwise_op,
                  AccElementwiseOp acc_elementwise_op)
-            : Reduction::Argument(inLengths,
-                                  inStrides,
-                                  {},
-                                  {},
-                                  reduceDims,
-                                  0.0f, // alpha
-                                  0.0f, // beta
-                                  in_dev,
-                                  nullptr,
-                                  out_dev,
-                                  nullptr,
-                                  in_elementwise_op,
-                                  acc_elementwise_op),
-              // FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
-              // float32 precision. Make it support any data type so the fields can be removed.
-              alpha_(alpha),
-              beta_(beta)
+            : alpha_{alpha},
+              beta_{beta},
+              in_dev_{in_dev},
+              out_dev_{out_dev},
+              in_elementwise_op_{in_elementwise_op},
+              acc_elementwise_op_{acc_elementwise_op}
         {
-            // std::cout << "blkGroupSize= " << this->blkGroupSize
-            // << ", numBlockTileIteration= " << this->numBlockTileIteration
-            // << ", gridSize=" << this->gridSize
-            // << ", invariant_total_length=" << this->invariant_total_length <<
-            // std::endl;
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
+
+            long_index_t invariant_total_length;
+            long_index_t reduce_total_length;
+
+            std::tie(invariant_total_length, reduce_total_length) =
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);
+
+            if constexpr(NumInvariantDim == 0)
+                invariant_lowest_length_ = 1;
+            else
+                invariant_lowest_length_ = inLengths_[NumInvariantDim - 1];
+
+            blkGroupSize          = 1;
+            numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
+
+            gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
+                       M_BlockTileSize * blkGroupSize;
         }

+        std::vector<index_t> inLengths_;
+        std::vector<index_t> inStrides_;
+
         AccDataType alpha_;
         AccDataType beta_;
+
+        const InDataType* in_dev_;
+        OutDataType* out_dev_;
+
+        InElementwiseOp in_elementwise_op_;
+        AccElementwiseOp acc_elementwise_op_;
+
+        index_t invariant_lowest_length_;
+
+        int blkGroupSize;
+        int numBlockTileIteration;
+        size_t gridSize;
     };

     struct Invoker : public BaseInvoker
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
+            const auto in_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
                 arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
-            const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
+            const auto out_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
                 arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);

             bool sweep_once =
...
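With the DeviceReduceMultiBlock base gone, the Argument constructor above now derives the launch shape itself: blkGroupSize stays 1, so one block covers the whole reduce length by iterating K tiles, and the grid covers the M extent one block tile at a time. A sketch of that arithmetic with illustrative tile sizes (not a specific instantiation from this commit):

// Hedged sketch of the launch-size math in the rewritten Argument constructor above.
#include <cstddef>
#include <cstdint>
#include <iostream>

int main()
{
    const std::int64_t invariant_total_length = 16 * 14 * 14; // merged M extent (example)
    const std::int64_t reduce_total_length    = 1024;         // merged K extent (example)

    const std::int64_t M_BlockTileSize = 64;  // MThreadClusterSize * MThreadSliceSize (example)
    const std::int64_t K_BlockTileSize = 256; // KThreadClusterSize * KThreadSliceSize (example)

    const int blkGroupSize = 1; // this path never splits the reduce dimension across blocks

    const int numBlockTileIteration =
        static_cast<int>((reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize);

    const std::size_t gridSize =
        (invariant_total_length + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize;

    std::cout << "iterations per block = " << numBlockTileIteration // 4
              << ", gridSize = " << gridSize << '\n';               // 49 workgroups
}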
@@ -195,15 +261,32 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
         {
             const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

-            if(!Reduction::IsSupportedArgument(p_arg_))
-            {
-                return false;
-            }
-
-            if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
-            {
-                return false;
-            }
+            if constexpr(InSrcVectorDim == 0)
+            {
+                if constexpr(NumInvariantDim == 0)
+                {
+                    return false;
+                }
+                else
+                {
+                    if(p_arg_->inStrides_[NumInvariantDim - 1] != 1)
+                        return false;
+
+                    if(p_arg_->invariant_lowest_length_ % InSrcVectorSize != 0)
+                        return false;
+                };
+            }
+            else
+            {
+                if(p_arg_->inStrides_[Rank - 1] != 1)
+                    return false;
+
+                if(p_arg_->inLengths_[Rank - 1] % InSrcVectorSize != 0)
+                    return false;
+            };
+
+            if(p_arg_->invariant_lowest_length_ % OutDstVectorSize != 0)
+                return false;

             return true;
         };
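The rewritten IsSupportedArgument replaces the delegated Reduction::IsSupportedArgument call with explicit vector-access checks: the dimension chosen by InSrcVectorDim must be contiguous (stride 1) and its length divisible by InSrcVectorSize, and the lowest invariant length must be divisible by OutDstVectorSize. A standalone restatement of those checks (a sketch under assumed names, not the library's function):

// Hedged sketch mirroring the support checks above.
#include <cstdint>
#include <iostream>
#include <vector>

using index_t = std::int32_t;

bool softmax_vector_access_supported(const std::vector<index_t>& inLengths, // size == Rank
                                     const std::vector<index_t>& inStrides,
                                     index_t Rank,
                                     index_t NumInvariantDim,
                                     index_t InSrcVectorDim,
                                     index_t InSrcVectorSize,
                                     index_t OutDstVectorSize)
{
    const index_t invariant_lowest_length =
        (NumInvariantDim == 0) ? 1 : inLengths[NumInvariantDim - 1];

    if(InSrcVectorDim == 0)
    {
        // vector access along the lowest invariant dimension
        if(NumInvariantDim == 0)
            return false;
        if(inStrides[NumInvariantDim - 1] != 1)
            return false;
        if(invariant_lowest_length % InSrcVectorSize != 0)
            return false;
    }
    else
    {
        // vector access along the lowest reduce dimension
        if(inStrides[Rank - 1] != 1)
            return false;
        if(inLengths[Rank - 1] % InSrcVectorSize != 0)
            return false;
    }

    return invariant_lowest_length % OutDstVectorSize == 0;
}

int main()
{
    // NHWC input, softmax over C: Rank = 4, NumInvariantDim = 3, last dim contiguous
    std::cout << std::boolalpha
              << softmax_vector_access_supported({16, 14, 14, 1024}, {200704, 14336, 1024, 1},
                                                 4, 3, /*InSrcVectorDim=*/1,
                                                 /*InSrcVectorSize=*/8, /*OutDstVectorSize=*/2)
              << '\n'; // true
}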
...