Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dc70e3e1
Unverified
Commit
dc70e3e1
authored
Nov 01, 2022
by
arai713
Committed by
GitHub
Nov 01, 2022
Browse files
Merge branch 'develop' into gridwise_2d
parents
10947a54
8ee36118
Changes
105
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2664 additions
and
902 deletions
+2664
-902
example/34_batchnorm/batchnorm_forward_nhwc.cpp
example/34_batchnorm/batchnorm_forward_nhwc.cpp
+222
-95
example/34_batchnorm/batchnorm_infer_impl.hpp
example/34_batchnorm/batchnorm_infer_impl.hpp
+27
-15
example/34_batchnorm/batchnorm_infer_nhwc.cpp
example/34_batchnorm/batchnorm_infer_nhwc.cpp
+35
-21
include/ck/ck.hpp
include/ck/ck.hpp
+5
-0
include/ck/tensor_description/tensor_space_filling_curve.hpp
include/ck/tensor_description/tensor_space_filling_curve.hpp
+6
-4
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+36
-0
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp
...operation/gpu/device/device_batched_gemm_softmax_gemm.hpp
+2
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
...n/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
+39
-28
include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp
.../tensor_operation/gpu/device/device_batchnorm_forward.hpp
+13
-8
include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp
...ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp
+3
-1
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp
...k/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp
+56
-0
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
.../gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
+837
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
...n/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
+25
-19
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+235
-313
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
...pl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+6
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+316
-370
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
...ce/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+7
-24
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
...eration/gpu/device/impl/device_batchnorm_forward_impl.hpp
+711
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
...de/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
+1
-0
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
...ck/tensor_operation/gpu/device/masking_specialization.hpp
+82
-0
No files found.
example/34_batchnorm/batchnorm_forward_nhwc.cpp
View file @
dc70e3e1
...
...
@@ -15,13 +15,9 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
#include "batchnorm_forward_impl.hpp"
template
<
typename
InOutDataType
,
typename
AccDataType
>
using
ReferenceBatchNormFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
<
InOutDataType
,
AccDataType
>
;
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
static
struct
option
long_options
[]
=
{{
"inOutLengths"
,
required_argument
,
nullptr
,
'D'
},
{
"verify"
,
required_argument
,
nullptr
,
'v'
},
...
...
@@ -44,6 +40,7 @@ class BatchNormFwdArg
int
data_type
=
0
;
int
init_method
=
2
;
bool
time_kernel
=
false
;
bool
use_multiblock_welford
=
false
;
public:
void
show_usage
(
const
char
*
cmd
)
...
...
@@ -68,6 +65,7 @@ class BatchNormFwdArg
"value, 3=decimal value)"
<<
std
::
endl
;
std
::
cout
<<
"Arg5: time kernel (0=no, 1=yes)"
<<
std
::
endl
;
std
::
cout
<<
"Arg6: use multi-block welford (0=n0, 1=yes)"
<<
std
::
endl
;
};
int
processArgs
(
int
argc
,
char
*
argv
[])
...
...
@@ -110,14 +108,15 @@ class BatchNormFwdArg
};
};
if
(
optind
+
5
>
argc
)
if
(
optind
+
6
>
argc
)
throw
std
::
runtime_error
(
"Invalid cmd-line arguments, more argumetns are needed!"
);
data_type
=
std
::
atoi
(
argv
[
optind
++
]);
updateMovingAverage
=
std
::
atoi
(
argv
[
optind
++
]);
saveMeanAndInvVariance
=
std
::
atoi
(
argv
[
optind
++
]);
init_method
=
std
::
atoi
(
argv
[
optind
++
]);
time_kernel
=
static_cast
<
bool
>
(
std
::
atoi
(
argv
[
optind
]));
time_kernel
=
static_cast
<
bool
>
(
std
::
atoi
(
argv
[
optind
++
]));
use_multiblock_welford
=
static_cast
<
bool
>
(
std
::
atoi
(
argv
[
optind
]));
if
(
data_type
!=
0
&&
data_type
!=
1
&&
data_type
!=
3
&&
data_type
!=
5
&&
data_type
!=
6
)
return
(
-
1
);
...
...
@@ -128,7 +127,7 @@ class BatchNormFwdArg
using
namespace
ck
;
template
<
typename
InOutDataType
,
typename
AccDataType
>
template
<
typename
InOutDataType
,
typename
AccDataType
,
bool
UseMultiblockInK
>
bool
bnorm_fwd_nhwc_test
(
bool
do_verification
,
int
init_method
,
bool
time_kernel
,
...
...
@@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
scaleBiasMeanVarStrides
.
end
(),
i_scaleBiasMeanVarStrides
.
begin
());
int
result
=
0
;
// used for saving meansquare
DeviceMem
workspace
(
sizeof
(
AccDataType
)
*
2
*
resultSaveMean_ref
.
mDesc
.
GetElementSpaceSize
()
+
128
);
using
PassThroughOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceBatchNormFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchNormFwdImpl
<
InOutDataType
,
InOutDataType
,
AccDataType
,
AccDataType
,
// ScaleDataType
AccDataType
,
// BiasDataType
AccDataType
,
// MeanVarDataType
PassThroughOp
,
// YElementwiseOp
Rank
,
NumReduceDim
,
UseMultiblockInK
,
256
,
16
,
16
,
1
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
;
void
*
p_tmp_mean
=
workspace
.
GetDeviceBuffer
();
void
*
p_tmp_meansquare
=
static_cast
<
char
*>
(
p_tmp_mean
)
+
(
sizeof
(
AccDataType
)
*
resultSaveMean_ref
.
mDesc
.
GetElementSpaceSize
()
+
63
)
/
64
*
64
;
auto
batchnorm_fwd
=
DeviceBatchNormFwdInstance
{};
result
=
bnorm_fwd
<
InOutDataType
,
AccDataType
,
Rank
,
NumReduceDim
,
false
>
(
time_kernel
,
updateMovingAverage
,
saveMeanAndInvVariance
,
{
0
,
1
,
2
},
auto
argument_ptr
=
batchnorm_fwd
.
MakeArgumentPointer
(
i_inOutLengths
,
i_inOutStrides
,
i_inOutStrides
,
{
0
,
1
,
2
},
i_scaleBiasMeanVarLengths
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
x_dev
.
GetDeviceBuffer
(),
bnScale_dev
.
GetDeviceBuffer
(),
bnBias_dev
.
GetDeviceBuffer
(),
y_dev
.
GetDeviceBuffer
(),
averageFactor
,
updateMovingAverage
?
resultRunningMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
updateMovingAverage
?
resultRunningVariance_dev
.
GetDeviceBuffer
()
:
nullptr
,
epsilon
,
PassThroughOp
{},
y_dev
.
GetDeviceBuffer
(),
saveMeanAndInvVariance
?
resultSaveMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_dev
.
GetDeviceBuffer
()
:
nullptr
,
p_tmp_mean
,
p_tmp_meansquare
);
averageFactor
,
updateMovingAverage
?
resultRunningMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
updateMovingAverage
?
resultRunningVariance_dev
.
GetDeviceBuffer
()
:
nullptr
);
if
(
result
<
0
)
if
(
!
batchnorm_fwd
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
cout
<<
"The runtime parameters seems not supported by the BatchNorm device instance, "
"exiting!"
<<
std
::
endl
;
return
(
false
);
};
size_t
workspace_sz
=
batchnorm_fwd
.
GetWorkSpaceSize
(
argument_ptr
.
get
());
DeviceMem
workspace_dev
(
workspace_sz
);
batchnorm_fwd
.
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace_dev
.
GetDeviceBuffer
());
auto
invoker_ptr
=
batchnorm_fwd
.
MakeInvokerPointer
();
if
(
time_kernel
)
{
float
avg_time
=
0.0
f
;
size_t
num_bytes
=
0
;
size_t
total_length
=
inOutLengths
[
0
]
*
inOutLengths
[
1
]
*
inOutLengths
[
2
]
*
inOutLengths
[
3
];
size_t
invariant_length
=
inOutLengths
[
3
];
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
// inputing of x, scale, bias, outputing of y
num_bytes
+=
total_length
*
sizeof
(
InOutDataType
)
*
2
+
invariant_length
*
sizeof
(
AccDataType
)
*
2
;
// outputing of mean, inv-variance
num_bytes
+=
saveMeanAndInvVariance
?
invariant_length
*
sizeof
(
AccDataType
)
*
2
:
0
;
// updating of moving mean, variance
num_bytes
+=
updateMovingAverage
?
invariant_length
*
sizeof
(
AccDataType
)
*
4
:
0
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
}
else
(
void
)
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
bool
pass
=
true
;
if
(
do_verification
)
{
auto
batchNormFwd_ref
=
ReferenceBatchNormFwdInstance
<
InOutDataType
,
AccDataType
>
{};
using
ReferenceBatchNormFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
<
InOutDataType
,
InOutDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
PassThroughOp
>
;
auto
batchNormFwd_ref
=
ReferenceBatchNormFwdInstance
{};
auto
argument_ptr_ref
=
batchNormFwd_ref
.
MakeArgumentPointer
(
i_inOutLengths
,
i_inOutStrides
,
i_inOutStrides
,
{
0
,
1
,
2
},
i_scaleBiasMeanVarLengths
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
x
.
mData
.
data
(),
bnScale
.
mData
.
data
(),
bnBias
.
mData
.
data
(),
y_ref
.
mData
.
data
(),
0.1
,
// exponentialAverageFactor
updateMovingAverage
?
resultRunningMean_ref
.
mData
.
data
()
:
nullptr
,
// resultRunningMean
updateMovingAverage
?
resultRunningVariance_ref
.
mData
.
data
()
:
nullptr
,
// resultRunningVariance
epsilon
,
PassThroughOp
{},
y_ref
.
mData
.
data
(),
saveMeanAndInvVariance
?
resultSaveMean_ref
.
mData
.
data
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_ref
.
mData
.
data
()
:
nullptr
);
saveMeanAndInvVariance
?
resultSaveInvVariance_ref
.
mData
.
data
()
:
nullptr
,
averageFactor
,
updateMovingAverage
?
resultRunningMean_ref
.
mData
.
data
()
:
nullptr
,
updateMovingAverage
?
resultRunningVariance_ref
.
mData
.
data
()
:
nullptr
);
if
(
!
batchNormFwd_ref
.
IsSupportedArgument
(
argument_ptr_ref
.
get
()))
{
std
::
cout
<<
"The runtime parameters seems not supported by the BatchNorm
instance, exiting!"
std
::
cout
<<
"The runtime parameters seems not supported by the BatchNorm reference "
"
instance, exiting!"
<<
std
::
endl
;
return
(
-
2
);
return
(
false
);
};
auto
invoker_ptr_ref
=
batchNormFwd_ref
.
MakeInvokerPointer
();
...
...
@@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
if
(
saveMeanAndInvVariance
)
{
using
ck
::
host_common
::
dumpBufferToFile
;
Tensor
<
AccDataType
>
resultSaveMean
(
scaleBiasMeanVarLengths
);
Tensor
<
AccDataType
>
resultSaveInvVariance
(
scaleBiasMeanVarLengths
);
...
...
@@ -396,7 +464,17 @@ int main(int argc, char* argv[])
if
(
arg
.
data_type
==
0
)
{
pass
=
bnorm_fwd_nhwc_test
<
ck
::
half_t
,
float
>
(
arg
.
do_verification
,
if
(
arg
.
use_multiblock_welford
)
pass
=
bnorm_fwd_nhwc_test
<
ck
::
half_t
,
float
,
true
>
(
arg
.
do_verification
,
arg
.
init_method
,
arg
.
time_kernel
,
arg
.
inOutLengths
,
arg
.
updateMovingAverage
,
arg
.
saveMeanAndInvVariance
,
averageFactor
,
epsilon
);
else
pass
=
bnorm_fwd_nhwc_test
<
ck
::
half_t
,
float
,
false
>
(
arg
.
do_verification
,
arg
.
init_method
,
arg
.
time_kernel
,
arg
.
inOutLengths
,
...
...
@@ -407,7 +485,17 @@ int main(int argc, char* argv[])
}
else
if
(
arg
.
data_type
==
1
)
{
pass
=
bnorm_fwd_nhwc_test
<
float
,
float
>
(
arg
.
do_verification
,
if
(
arg
.
use_multiblock_welford
)
pass
=
bnorm_fwd_nhwc_test
<
float
,
float
,
true
>
(
arg
.
do_verification
,
arg
.
init_method
,
arg
.
time_kernel
,
arg
.
inOutLengths
,
arg
.
updateMovingAverage
,
arg
.
saveMeanAndInvVariance
,
averageFactor
,
epsilon
);
else
pass
=
bnorm_fwd_nhwc_test
<
float
,
float
,
false
>
(
arg
.
do_verification
,
arg
.
init_method
,
arg
.
time_kernel
,
arg
.
inOutLengths
,
...
...
@@ -418,7 +506,17 @@ int main(int argc, char* argv[])
}
else
if
(
arg
.
data_type
==
3
)
{
pass
=
bnorm_fwd_nhwc_test
<
int8_t
,
float
>
(
arg
.
do_verification
,
if
(
arg
.
use_multiblock_welford
)
pass
=
bnorm_fwd_nhwc_test
<
int8_t
,
float
,
true
>
(
arg
.
do_verification
,
arg
.
init_method
,
arg
.
time_kernel
,
arg
.
inOutLengths
,
arg
.
updateMovingAverage
,
arg
.
saveMeanAndInvVariance
,
averageFactor
,
epsilon
);
else
pass
=
bnorm_fwd_nhwc_test
<
int8_t
,
float
,
false
>
(
arg
.
do_verification
,
arg
.
init_method
,
arg
.
time_kernel
,
arg
.
inOutLengths
,
...
...
@@ -429,7 +527,17 @@ int main(int argc, char* argv[])
}
else
if
(
arg
.
data_type
==
5
)
{
pass
=
bnorm_fwd_nhwc_test
<
ck
::
bhalf_t
,
float
>
(
arg
.
do_verification
,
if
(
arg
.
use_multiblock_welford
)
pass
=
bnorm_fwd_nhwc_test
<
ck
::
bhalf_t
,
float
,
true
>
(
arg
.
do_verification
,
arg
.
init_method
,
arg
.
time_kernel
,
arg
.
inOutLengths
,
arg
.
updateMovingAverage
,
arg
.
saveMeanAndInvVariance
,
averageFactor
,
epsilon
);
else
pass
=
bnorm_fwd_nhwc_test
<
ck
::
bhalf_t
,
float
,
false
>
(
arg
.
do_verification
,
arg
.
init_method
,
arg
.
time_kernel
,
arg
.
inOutLengths
,
...
...
@@ -440,7 +548,17 @@ int main(int argc, char* argv[])
}
else
if
(
arg
.
data_type
==
6
)
{
pass
=
bnorm_fwd_nhwc_test
<
double
,
double
>
(
arg
.
do_verification
,
if
(
arg
.
use_multiblock_welford
)
pass
=
bnorm_fwd_nhwc_test
<
double
,
double
,
true
>
(
arg
.
do_verification
,
arg
.
init_method
,
arg
.
time_kernel
,
arg
.
inOutLengths
,
arg
.
updateMovingAverage
,
arg
.
saveMeanAndInvVariance
,
averageFactor
,
epsilon
);
else
pass
=
bnorm_fwd_nhwc_test
<
double
,
double
,
false
>
(
arg
.
do_verification
,
arg
.
init_method
,
arg
.
time_kernel
,
arg
.
inOutLengths
,
...
...
@@ -452,12 +570,21 @@ int main(int argc, char* argv[])
}
else
{
pass
=
bnorm_fwd_nhwc_test
<
ck
::
half_t
,
float
>
(
true
,
pass
=
bnorm_fwd_nhwc_test
<
ck
::
half_t
,
float
,
true
>
(
true
,
2
,
false
,
// don't time kernel
{
128
,
16
,
16
,
1024
},
{
128
,
16
,
6
,
512
},
true
,
true
,
averageFactor
,
epsilon
);
pass
=
pass
&&
bnorm_fwd_nhwc_test
<
ck
::
half_t
,
float
,
false
>
(
true
,
2
,
false
,
// don't time kernel
{
128
,
16
,
3
,
1024
},
true
,
true
,
false
,
averageFactor
,
epsilon
);
};
...
...
example/34_batchnorm/batchnorm_infer_impl.hpp
View file @
dc70e3e1
...
...
@@ -14,8 +14,12 @@
#include "batchnorm_common.hpp"
template
<
typename
InOutDataType
,
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
ck
::
index_t
Rank
,
ck
::
index_t
NumBatchNormReduceDim
,
bool
fastest_dim_is_reduced
=
false
>
...
...
@@ -26,7 +30,9 @@ int bnorm_infer(
const
std
::
array
<
ck
::
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
ck
::
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
ck
::
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
ck
::
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarStrides
,
const
std
::
array
<
ck
::
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides
,
const
std
::
array
<
ck
::
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides
,
const
std
::
array
<
ck
::
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
p_scale
,
const
void
*
p_bias
,
...
...
@@ -41,11 +47,11 @@ int bnorm_infer(
"Invalid number of reduced dimensions for batchnorm!"
);
using
DeviceNormalizeInstance
=
ck
::
tensor_operation
::
device
::
DeviceElementwise
<
ck
::
Tuple
<
InOut
DataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
>
,
// x, mean,
ck
::
Tuple
<
X
DataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
>
,
// x, mean,
// variance,
// scale,
// bias,
ck
::
Tuple
<
InOut
DataType
>
,
// y
ck
::
Tuple
<
Y
DataType
>
,
// y
NormalizeInInfer
,
Rank
,
2
,
// MPerthread
...
...
@@ -53,14 +59,18 @@ int bnorm_infer(
ck
::
Sequence
<
1
>>
;
// scalarPerVector: y
auto
invariantDims
=
get_invariant_dims
<
Rank
,
NumBatchNormReduceDim
>
(
reduceDims
);
std
::
array
<
ck
::
index_t
,
Rank
>
aligned_scaleBiasMeanVarStrides
{
0
};
std
::
array
<
ck
::
index_t
,
Rank
>
aligned_bnScaleStrides
{
0
};
std
::
array
<
ck
::
index_t
,
Rank
>
aligned_bnBiasStrides
{
0
};
std
::
array
<
ck
::
index_t
,
Rank
>
aligned_bnMeanVarStrides
{
0
};
int
i
=
0
;
for
(
auto
dim
:
invariantDims
)
{
assert
(
xyLengths
[
dim
]
==
bnScaleBiasMeanVarLengths
[
i
]);
aligned_scaleBiasMeanVarStrides
[
dim
]
=
bnScaleBiasMeanVarStrides
[
i
];
aligned_bnScaleStrides
[
dim
]
=
bnScaleStrides
[
i
];
aligned_bnBiasStrides
[
dim
]
=
bnBiasStrides
[
i
];
aligned_bnMeanVarStrides
[
dim
]
=
bnMeanVarStrides
[
i
];
i
++
;
};
...
...
@@ -84,10 +94,10 @@ int bnorm_infer(
auto
argument_ptr1
=
dev_normalize
.
MakeArgumentPointer
(
xyLengths
,
{
xStrides
,
aligned_
scaleBias
MeanVarStrides
,
aligned_
scaleBias
MeanVarStrides
,
aligned_
s
cale
BiasMeanVar
Strides
,
aligned_
scaleBiasMeanVar
Strides
},
aligned_
bn
MeanVarStrides
,
aligned_
bn
MeanVarStrides
,
aligned_
bnS
caleStrides
,
aligned_
bnBias
Strides
},
{
yStrides
},
{
p_x
,
p_estimatedMean
,
p_estimatedVariance
,
p_scale
,
p_bias
},
{
p_y
},
...
...
@@ -105,8 +115,10 @@ int bnorm_infer(
avg_time
+=
invoker_ptr1
->
Run
(
argument_ptr1
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
num_bytes
+=
(
total_length
*
(
1
*
sizeof
(
InOutDataType
)
+
4
*
sizeof
(
AccDataType
))
+
total_length
*
sizeof
(
InOutDataType
));
num_bytes
+=
total_length
*
sizeof
(
XDataType
)
+
invariantLength
*
(
sizeof
(
ScaleDataType
)
+
sizeof
(
BiasDataType
)
+
2
*
sizeof
(
MeanVarDataType
))
+
total_length
*
sizeof
(
YDataType
);
if
(
time_kernel
)
{
...
...
example/34_batchnorm/batchnorm_infer_nhwc.cpp
View file @
dc70e3e1
...
...
@@ -18,11 +18,6 @@
#include "batchnorm_infer_impl.hpp"
template
<
typename
InOutDataType
,
typename
AccDataType
>
using
ReferenceBatchNormInferInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchNormInfer_Input_N_H_W_C_Output_C
<
InOutDataType
,
AccDataType
>
;
static
struct
option
long_options
[]
=
{{
"inOutLengths"
,
required_argument
,
nullptr
,
'D'
},
{
"verify"
,
required_argument
,
nullptr
,
'v'
},
{
"help"
,
no_argument
,
nullptr
,
'?'
},
...
...
@@ -236,14 +231,23 @@ bool bnorm_infer_nhwc_test(bool do_verification,
int
result
=
0
;
result
=
bnorm_infer
<
InOutDataType
,
AccDataType
,
Rank
,
NumReduceDim
,
false
>
(
time_kernel
,
result
=
bnorm_infer
<
InOutDataType
,
InOutDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
Rank
,
NumReduceDim
,
false
>
(
time_kernel
,
{
0
,
1
,
2
},
i_inOutLengths
,
i_inOutStrides
,
i_inOutStrides
,
i_scaleBiasMeanVarLengths
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
x_dev
.
GetDeviceBuffer
(),
bnScale_dev
.
GetDeviceBuffer
(),
bnBias_dev
.
GetDeviceBuffer
(),
...
...
@@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification,
if
(
do_verification
)
{
auto
batchNormInfer_ref
=
ReferenceBatchNormInferInstance
<
InOutDataType
,
AccDataType
>
{};
using
ReferenceBatchNormInferInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchNormInfer_Input_N_H_W_C_Output_C
<
InOutDataType
,
InOutDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
>
;
auto
batchNormInfer_ref
=
ReferenceBatchNormInferInstance
{};
auto
argument_ptr_ref
=
batchNormInfer_ref
.
MakeArgumentPointer
(
i_inOutLengths
,
...
...
@@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
i_inOutStrides
,
i_scaleBiasMeanVarLengths
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
x
.
mData
.
data
(),
bnScale
.
mData
.
data
(),
bnBias
.
mData
.
data
(),
...
...
include/ck/ck.hpp
View file @
dc70e3e1
...
...
@@ -159,6 +159,11 @@
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0
// workaround: disable broken fused attention kernel instance that does not pass validation
// issue found on mi100/#10738 combo when irregular KPerBlock attention kernel has acc0 scaling
// enabled
#define CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE 1
namespace
ck
{
enum
struct
InMemoryDataOperationEnum
...
...
include/ck/tensor_description/tensor_space_filling_curve.hpp
View file @
dc70e3e1
...
...
@@ -14,7 +14,8 @@ namespace ck {
template
<
typename
TensorLengths
,
typename
DimAccessOrder
,
typename
ScalarsPerAccess
>
// # of scalars per access in each dimension
typename
ScalarsPerAccess
,
bool
SnakeCurved
=
true
>
// # of scalars per access in each dimension
struct
SpaceFillingCurve
{
static
constexpr
index_t
nDim
=
TensorLengths
::
Size
();
...
...
@@ -136,9 +137,10 @@ struct SpaceFillingCurve
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
idim
)
{
ordered_idx
(
idim
)
=
forward_sweep
[
idim
]
?
ordered_access_idx
[
idim
]
:
ordered_access_lengths
[
idim
]
-
1
-
ordered_access_idx
[
idim
];
ordered_idx
(
idim
)
=
!
SnakeCurved
||
forward_sweep
[
idim
]
?
ordered_access_idx
[
idim
]
:
ordered_access_lengths
[
idim
]
-
1
-
ordered_access_idx
[
idim
];
});
return
container_reorder_given_old2new
(
ordered_idx
,
dim_access_order
)
*
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
dc70e3e1
...
...
@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex8D
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk4D
(
xdlops_i
,
blk_i
);
return
make_tuple
(
Number
<
m0
>
{},
Number
<
n0
>
{},
waveId_m
,
waveId_n
,
blk_idx
[
I0
],
blk_idx
[
I1
],
blk_idx
[
I2
],
blk_idx
[
I3
]);
}
__host__
__device__
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
()
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
...
...
@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex8D
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk4D
(
xdlops_i
,
blk_i
);
return
make_tuple
(
m0
,
n0
,
waveId_m
,
waveId_n
,
blk_idx
[
I0
],
blk_idx
[
I1
],
blk_idx
[
I2
],
blk_idx
[
I3
]);
}
using
Tuple4
=
decltype
(
CalculateAThreadOriginDataIndex
());
__host__
__device__
BlockwiseGemmXdlops_v2
(
Tuple4
a_origin
=
CalculateAThreadOriginDataIndex
(),
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp
View file @
dc70e3e1
...
...
@@ -24,7 +24,8 @@ template <typename ALayout,
typename
B0ElementwiseOperation
,
typename
Acc0ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
,
bool
MaskOutUpperTriangle
>
// TODO: enum for mask type
struct
DeviceBatchedGemmSoftmaxGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
View file @
dc70e3e1
...
...
@@ -7,44 +7,55 @@
#include <vector>
#include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
B0Layout
,
typename
B1Layout
,
typename
CPermuteNumDims_G_M_Gemm1N
,
// Sequence<>
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
Acc0ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
struct
DeviceBatchedGemmSoftmaxGemmPermute
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b0
,
const
void
*
p_b1
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
O
,
ck
::
index_t
Batch
,
std
::
vector
<
index_t
>
c_gs_ms_os_lengths
,
std
::
vector
<
index_t
>
c_gs_ms_os_strides
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB0
,
ck
::
index_t
StrideB1
,
ck
::
index_t
BatchStrideA
,
ck
::
index_t
BatchStrideB0
,
ck
::
index_t
BatchStrideB1
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
B0ElementwiseOperation
b0_element_op
,
Acc0ElementwiseOperation
acc0_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp
View file @
dc70e3e1
...
...
@@ -13,31 +13,36 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
typename
YElementwiseOp
>
struct
DeviceBatchNormFwd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
bnScale
,
const
void
*
bnBias
,
double
epsilon
,
const
YElementwiseOp
y_elementwise_op
,
void
*
p_y
,
void
*
resultSaveMean
,
void
*
resultSaveInvVariance
,
double
exponentialAverageFactor
,
void
*
resultRunningMean
,
void
*
resultRunningVariance
,
double
epsilon
,
void
*
resultSaveMean
,
void
*
resultSaveInvVariance
)
=
0
;
void
*
resultRunningVariance
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
using
DeviceBatchNormFwdPtr
=
std
::
unique_ptr
<
DeviceBatchNormFwd
<
Rank
,
NumBatchNormReduceDim
>>
;
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
typename
YElementwiseOp
>
using
DeviceBatchNormFwdPtr
=
std
::
unique_ptr
<
DeviceBatchNormFwd
<
Rank
,
NumBatchNormReduceDim
,
YElementwiseOp
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp
View file @
dc70e3e1
...
...
@@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
bnScale
,
const
void
*
bnBias
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp
0 → 100644
View file @
dc70e3e1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Convolution Forward:
// input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedConvFwd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
// input image
const
void
*
p_b
,
// weight
void
*
p_c
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
0 → 100644
View file @
dc70e3e1
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
View file @
dc70e3e1
...
...
@@ -7,46 +7,50 @@
#include <vector>
#include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
B0Layout
,
typename
B1Layout
,
typename
CPermuteNumDims_G_M_Gemm1N
,
// Sequence<>
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
Acc0ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
struct
DeviceGroupedGemmSoftmaxGemmPermute
:
public
BaseOperator
{
struct
ProblemDesc
{
// Overall problem shape
index_t
M
;
index_t
N
;
index_t
K
;
index_t
O
;
index_t
Batch
;
std
::
vector
<
index_t
>
a_gs_ms_ks_lengths
;
std
::
vector
<
index_t
>
a_gs_ms_ks_strides
;
// Stride for A/B0/B1; layout determined by template args
index_t
StrideA
;
index_t
StrideB0
;
index_t
StrideB1
;
index_t
BatchStrideA
;
index_t
BatchStrideB0
;
index_t
BatchStrideB1
;
std
::
vector
<
index_t
>
b0_gs_ns_ks_lengths
;
std
::
vector
<
index_t
>
b0_gs_ns_ks_strides
;
std
::
vector
<
index_t
>
b1_gs_os_ns_lengths
;
std
::
vector
<
index_t
>
b1_gs_os_ns_strides
;
// Lengths and strides for output C
std
::
vector
<
index_t
>
c_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
c_gs_ms_os_strides
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc0_biases_gs_ms_ns_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc0_biases_gs_ms_ns_strides
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc1_biases_gs_ms_os_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc1_biases_gs_ms_os_strides
;
};
virtual
std
::
unique_ptr
<
BaseArgument
>
...
...
@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
std
::
vector
<
const
void
*>
p_b0_vec
,
std
::
vector
<
const
void
*>
p_b1_vec
,
std
::
vector
<
void
*>
p_c_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc0_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc1_biases_vec
,
std
::
vector
<
ProblemDesc
>
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
B0ElementwiseOperation
b0_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
dc70e3e1
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
View file @
dc70e3e1
...
...
@@ -130,8 +130,11 @@ namespace device {
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases, it
// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1
// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner
// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
// TensorSpecialization::Default with NumDimG/M/N/K = 1
//
// Detail- Packed tensor satisfies
// stride_0 = 1
...
...
@@ -147,7 +150,7 @@ namespace device {
// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed
.
// TensorSpecialization::Packed
in a traditional sense of "packed" tensor
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
dc70e3e1
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
dc70e3e1
...
...
@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
...
...
@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
>
CElementwiseOperation
,
MaskOutUpperTriangle
>
{
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
;
...
...
@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
matrix_padder
.
PadCDescriptor_M_N
(
c_grid_desc_mraw_nraw
);
}
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct
C0MatrixMask
{
C0MatrixMask
(
index_t
NRaw
)
:
NRaw_
(
NRaw
)
{}
__host__
__device__
bool
IsUpperTriangle
(
index_t
m
,
index_t
n
)
const
{
return
n
>
m
;
}
__host__
__device__
bool
IsNOutOfBound
(
/*index_t m, */
index_t
n
)
const
{
return
n
>=
NRaw_
;
}
__host__
__device__
bool
IsMaskedElement
(
index_t
m
,
index_t
n
)
const
{
return
IsUpperTriangle
(
m
,
n
)
||
IsNOutOfBound
(
n
);
}
private:
// index_t MRaw_;
index_t
NRaw_
;
};
struct
ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch
(
index_t
BatchStrideA
,
...
...
@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
C0MatrixMask
=
conditional_t
<
MaskOutUpperTriangle
,
C0MatrixMask_impl
<
MaskOutUpperTrianglePredicate
>
,
C0MatrixMask_impl
<
MaskDisabledPredicate
>>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
0 → 100644
View file @
dc70e3e1
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
View file @
dc70e3e1
...
...
@@ -214,6 +214,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
K1
,
M1PerThread
,
N1PerThread
,
KPerThread
,
...
...
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
0 → 100644
View file @
dc70e3e1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
enum
struct
MaskingSpecialization
{
MaskDisabled
,
MaskOutUpperTriangle
};
inline
std
::
string
getMaskingSpecializationString
(
const
MaskingSpecialization
&
s
)
{
switch
(
s
)
{
case
MaskingSpecialization
::
MaskDisabled
:
return
"MaskDisabled"
;
case
MaskingSpecialization
::
MaskOutUpperTriangle
:
return
"MaskOutUpperTriangle"
;
default:
return
"Unrecognized specialization!"
;
}
}
struct
MaskDisabledPredicate
{
__host__
__device__
constexpr
bool
operator
()(
index_t
/*m*/
,
index_t
/*n*/
)
const
{
return
false
;
};
__host__
__device__
constexpr
bool
IsTileSkippable
(
index_t
/*m*/
,
index_t
/*n*/
,
index_t
/*m_tile*/
,
index_t
/*n_tile*/
)
const
{
return
false
;
}
};
struct
MaskOutUpperTrianglePredicate
{
__host__
__device__
constexpr
bool
operator
()(
index_t
m
,
index_t
n
)
const
{
return
n
>
m
;
}
__host__
__device__
constexpr
bool
IsTileSkippable
(
index_t
m
,
index_t
n
,
index_t
m_tile
,
index_t
/*n_tile*/
)
const
{
return
operator
()(
m
+
m_tile
-
1
,
n
);
}
};
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
template
<
typename
MaskOutPredicate
>
struct
C0MatrixMask_impl
{
C0MatrixMask_impl
(
index_t
NRaw
)
:
NRaw_
(
NRaw
),
predicate_
(
MaskOutPredicate
{})
{}
__host__
__device__
constexpr
bool
IsNOutOfBound
(
/*index_t m, */
index_t
n
)
const
{
return
n
>=
NRaw_
;
}
__host__
__device__
constexpr
bool
IsMaskedElement
(
index_t
m
,
index_t
n
)
const
{
return
predicate_
(
m
,
n
)
||
IsNOutOfBound
(
n
);
}
__host__
__device__
constexpr
bool
IsTileSkippable
(
index_t
m
,
index_t
n
,
index_t
m_tile
,
index_t
n_tile
)
const
{
return
predicate_
.
IsTileSkippable
(
m
,
n
,
m_tile
,
n_tile
);
}
private:
// index_t MRaw_;
index_t
NRaw_
;
MaskOutPredicate
predicate_
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment