Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2c1ed8b2
Commit
2c1ed8b2
authored
Jun 20, 2022
by
Anthony Chang
Browse files
Merge remote-tracking branch 'upstream/develop' into gemm-layernorm-4
parents
b86b318b
56adf7e9
Changes
103
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3098 additions
and
197 deletions
+3098
-197
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp
...nvnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp
+427
-0
example/21_gemm_layernorm/CMakeLists.txt
example/21_gemm_layernorm/CMakeLists.txt
+1
-0
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp
..._gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp
+425
-0
example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp
example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp
+21
-20
example/CMakeLists.txt
example/CMakeLists.txt
+1
-1
include/ck/tensor_description/tensor_adaptor.hpp
include/ck/tensor_description/tensor_adaptor.hpp
+4
-0
include/ck/tensor_description/tensor_descriptor.hpp
include/ck/tensor_description/tensor_descriptor.hpp
+7
-0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
...ation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
+169
-0
include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp
...k/tensor_operation/gpu/device/device_5ary_elementwise.hpp
+0
-1
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+7
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
...on/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
+34
-32
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
..._operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
+28
-30
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
..._fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
+8
-15
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+244
-28
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
...n/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
+813
-0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
...ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
+52
-0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
+750
-0
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
...ude/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
+55
-11
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
..._operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
+28
-27
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+24
-31
No files found.
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp
0 → 100644
View file @
2c1ed8b2
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "conv_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "tensor_layout.hpp"
#include "element_wise_operation.hpp"
#include "device_unary_elementwise.hpp"
#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "reference_conv_backward_weight.hpp"
using
InDataType
=
ck
::
bhalf_t
;
using
WeiDataType
=
ck
::
bhalf_t
;
using
OutDataType
=
ck
::
bhalf_t
;
using
AccDataType
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
UnaryTypeConvert
=
ck
::
tensor_operation
::
element_wise
::
UnaryTypeConvert
<
ck
::
bhalf_t
,
float
>
;
using
DeviceUnaryElementwiseTypeConvertInstance
=
ck
::
tensor_operation
::
device
::
DeviceUnaryElementwise
<
AccDataType
,
WeiDataType
,
UnaryTypeConvert
,
1
,
4
>
;
static
constexpr
auto
ConvBwdWeightDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionBackwardWeightSpecialization
::
Default
;
using
DeviceConvBwdWeightBasePtr
=
ck
::
tensor_operation
::
device
::
DeviceConvBwdWeightPtr
<
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
// clang-format off
template
<
ck
::
index_t
NumDimSpatial
>
using
DeviceConvndBwdWeightInstance_bf16_splitk
=
ck
::
tensor_operation
::
device
::
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
InDataType
,
// InDataType
AccDataType
,
// WeiDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvBwdWeightDefault
,
// ConvolutionBackwardWeightSpecialization
NumDimSpatial
,
// NumDimSpatial
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
2
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
2
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
4
>
,
// CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
template
<
ck
::
index_t
NumDimSpatial
>
using
ReferenceConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
NumDimSpatial
>
;
template
<
typename
HostTensorB
,
typename
HostTensorA
,
typename
Functor
>
void
host_elementwise
(
HostTensorB
&
B
,
const
HostTensorA
&
A
,
const
std
::
vector
<
std
::
size_t
>&
shape
,
Functor
functor
)
{
size_t
tensor_size
=
std
::
accumulate
(
shape
.
begin
(),
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
{});
std
::
cout
<<
__LINE__
<<
":"
<<
tensor_size
<<
", "
<<
A
.
mData
[
0
]
<<
std
::
endl
;
for
(
std
::
size_t
n
=
0
;
n
<
tensor_size
;
++
n
)
{
B
.
mData
[
n
]
=
functor
(
A
.
mData
[
n
]);
}
}
void
print_use_msg
()
{
std
::
cout
<<
"arg1: verification (0=no, 1=yes)
\n
"
<<
"arg2: initialization (0=no init, 1=random value, 2= init to 1 )
\n
"
<<
"arg3: time kernel (0=n0, 1=yes)
\n
"
<<
"arg4: is show log (0=no, 1=yes)
\n
"
<<
"arg5: split-k : in this example split-k must be larger than 1
\n
"
<<
"arg6: N spatial dimensions (default 2)
\n
"
<<
"Following arguments (depending on number of spatial dims):
\n
"
<<
" N, K, C,
\n
"
<<
" <filter spatial dimensions>, (ie Y, X for 2D)
\n
"
<<
" <input image spatial dimensions>, (ie Hi, Wi for 2D)
\n
"
<<
" <strides>, (ie Sy, Sx for 2D)
\n
"
<<
" <dilations>, (ie Dy, Dx for 2D)
\n
"
<<
" <left padding>, (ie LeftPy, LeftPx for 2D)
\n
"
<<
" <right padding>, (ie RightPy, RightPx for 2D)
\n
"
<<
std
::
endl
;
}
ck
::
utils
::
conv
::
ConvParams
parse_conv_params
(
int
num_dim_spatial
,
char
*
argv
[])
{
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck
::
utils
::
conv
::
ConvParams
params
;
int
arg_idx
=
7
;
params
.
num_dim_spatial_
=
num_dim_spatial
;
params
.
N_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
K_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
C_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
filter_spatial_lengths_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
filter_spatial_lengths_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_spatial_lengths_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_spatial_lengths_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
conv_filter_strides_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
conv_filter_strides_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
conv_filter_dilations_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
conv_filter_dilations_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_left_pads_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_left_pads_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_right_pads_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_right_pads_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
return
params
;
}
DeviceConvBwdWeightBasePtr
get_conv_instance
(
int
num_dim_spatial
)
{
switch
(
num_dim_spatial
)
{
case
3
:
{
return
std
::
make_unique
<
DeviceConvndBwdWeightInstance_bf16_splitk
<
3
>>
();
}
case
2
:
{
return
std
::
make_unique
<
DeviceConvndBwdWeightInstance_bf16_splitk
<
2
>>
();
}
case
1
:
{
return
std
::
make_unique
<
DeviceConvndBwdWeightInstance_bf16_splitk
<
1
>>
();
}
default:
{
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
}
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
int
num_dim_spatial
=
2
;
int
do_log
=
0
;
int
split_k
=
2
;
ck
::
utils
::
conv
::
ConvParams
params
;
params
.
C_
=
128
;
if
(
argc
==
6
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
do_log
=
std
::
stoi
(
argv
[
4
]);
split_k
=
std
::
stoi
(
argv
[
5
]);
}
else
if
(
argc
>
6
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
do_log
=
std
::
stoi
(
argv
[
4
]);
split_k
=
std
::
stoi
(
argv
[
5
]);
num_dim_spatial
=
std
::
stoi
(
argv
[
6
]);
// check args number
int
conv_args
=
3
+
num_dim_spatial
*
6
;
int
cmdline_nargs
=
conv_args
+
7
;
if
(
cmdline_nargs
!=
argc
)
{
print_use_msg
();
exit
(
1
);
}
params
=
parse_conv_params
(
num_dim_spatial
,
argv
);
}
else
if
(
argc
!=
1
)
{
print_use_msg
();
exit
(
1
);
}
if
(
split_k
<=
1
)
{
print_use_msg
();
exit
(
1
);
}
std
::
vector
<
std
::
size_t
>
input_dims
{
static_cast
<
std
::
size_t
>
(
params
.
N_
),
static_cast
<
std
::
size_t
>
(
params
.
C_
)};
input_dims
.
insert
(
std
::
end
(
input_dims
),
std
::
begin
(
params
.
input_spatial_lengths_
),
std
::
end
(
params
.
input_spatial_lengths_
));
std
::
vector
<
std
::
size_t
>
filter_dims
{
static_cast
<
std
::
size_t
>
(
params
.
K_
),
static_cast
<
std
::
size_t
>
(
params
.
C_
)};
filter_dims
.
insert
(
std
::
end
(
filter_dims
),
std
::
begin
(
params
.
filter_spatial_lengths_
),
std
::
end
(
params
.
filter_spatial_lengths_
));
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
=
params
.
GetOutputSpatialLengths
();
std
::
vector
<
std
::
size_t
>
output_dims
{
static_cast
<
std
::
size_t
>
(
params
.
N_
),
static_cast
<
std
::
size_t
>
(
params
.
K_
)};
output_dims
.
insert
(
std
::
end
(
output_dims
),
std
::
begin
(
output_spatial_lengths
),
std
::
end
(
output_spatial_lengths
));
Tensor
<
InDataType
>
in_n_c_hi_wi
(
ck
::
utils
::
conv
::
get_input_host_tensor_descriptor
(
input_dims
,
num_dim_spatial
));
Tensor
<
WeiDataType
>
wei_k_c_y_x_host_result
(
ck
::
utils
::
conv
::
get_filters_host_tensor_descriptor
(
filter_dims
,
num_dim_spatial
));
Tensor
<
WeiDataType
>
wei_k_c_y_x_device_result
(
ck
::
utils
::
conv
::
get_filters_host_tensor_descriptor
(
filter_dims
,
num_dim_spatial
));
Tensor
<
OutDataType
>
out_n_k_ho_wo
(
ck
::
utils
::
conv
::
get_output_host_tensor_descriptor
(
output_dims
,
num_dim_spatial
));
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_k_ho_wo: "
<<
out_n_k_ho_wo
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_k_ho_wo: "
<<
out_n_k_ho_wo
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
2
,
2
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
2
,
2
});
break
;
default:
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_1
<
OutDataType
>
{
1
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
1
});
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo
.
mDesc
.
GetElementSpace
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
out_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
// reset input to zero
wei_device_buf
.
SetZero
();
// do GEMM
auto
conv
=
get_conv_instance
(
num_dim_spatial
);
auto
invoker
=
conv
->
MakeInvokerPointer
();
auto
argument
=
conv
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
params
.
N_
,
params
.
K_
,
params
.
C_
,
params
.
input_spatial_lengths_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
,
params
.
conv_filter_strides_
,
params
.
conv_filter_dilations_
,
params
.
input_left_pads_
,
params
.
input_right_pads_
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{},
split_k
);
// alloc work space
size_t
bwd_weight_workspace_size
=
conv
->
GetWorkSpaceSize
(
argument
.
get
());
if
(
bwd_weight_workspace_size
<=
0
)
{
print_use_msg
();
exit
(
1
);
}
float
conv_ave_time
=
0.
f
;
DeviceMem
wei_work_space_device_buf
(
bwd_weight_workspace_size
);
wei_work_space_device_buf
.
SetZero
();
conv
->
SetWorkSpacePointer
(
argument
.
get
(),
wei_work_space_device_buf
.
GetDeviceBuffer
());
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
{
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<<
std
::
endl
;
return
1
;
}
conv_ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
ck
::
utils
::
conv
::
get_flops
(
params
.
N_
,
params
.
C_
,
params
.
K_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
);
std
::
size_t
num_btype
=
ck
::
utils
::
conv
::
get_btype
<
InDataType
,
WeiDataType
,
OutDataType
>
(
params
.
N_
,
params
.
C_
,
params
.
K_
,
params
.
input_spatial_lengths_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
conv_ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
conv_ave_time
;
std
::
cout
<<
"Perf: conv: "
<<
conv_ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
if
(
do_verification
)
{
auto
verify_f
=
[
&
](
const
auto
&
ref_conv
)
{
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in_n_c_hi_wi
,
wei_k_c_y_x_host_result
,
out_n_k_ho_wo
,
params
.
conv_filter_strides_
,
params
.
conv_filter_dilations_
,
params
.
input_left_pads_
,
params
.
input_right_pads_
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{});
ref_invoker
.
Run
(
ref_argument
);
wei_device_buf
.
FromDevice
(
wei_k_c_y_x_device_result
.
mData
.
data
());
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out: "
,
out_n_k_ho_wo
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"in : "
,
in_n_c_hi_wi
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei_device(after): "
,
wei_k_c_y_x_device_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei_host : "
,
wei_k_c_y_x_host_result
.
mData
,
","
)
<<
std
::
endl
;
}
return
ck
::
utils
::
check_err
(
wei_k_c_y_x_device_result
.
mData
,
wei_k_c_y_x_host_result
.
mData
)
?
0
:
1
;
};
switch
(
num_dim_spatial
)
{
case
3
:
{
auto
ref_conv
=
ReferenceConvBwdWeightInstance
<
3
>
();
verify_f
(
ref_conv
);
break
;
}
case
2
:
{
auto
ref_conv
=
ReferenceConvBwdWeightInstance
<
2
>
();
verify_f
(
ref_conv
);
break
;
}
case
1
:
{
auto
ref_conv
=
ReferenceConvBwdWeightInstance
<
1
>
();
verify_f
(
ref_conv
);
break
;
}
default:
{
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
}
}
}
return
0
;
}
example/21_gemm_layernorm/CMakeLists.txt
View file @
2c1ed8b2
add_example_executable
(
example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp
)
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp
0 → 100644
View file @
2c1ed8b2
This diff is collapsed.
Click to expand it.
example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp
View file @
2c1ed8b2
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
#include <numeric>
#include <numeric>
#include <initializer_list>
#include <initializer_list>
#include <cstdlib>
#include <cstdlib>
#include <stdlib.h>
#include "check_err.hpp"
#include "check_err.hpp"
#include "config.hpp"
#include "config.hpp"
...
@@ -15,7 +14,6 @@
...
@@ -15,7 +14,6 @@
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -45,17 +43,14 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
...
@@ -45,17 +43,14 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSumOp
=
ck
::
reduce
::
Add
<
ReduceAccDataType
>
;
using
ReduceSumOp
=
ck
::
reduce
::
Add
;
using
DxsReduceOp
=
ck
::
Tuple
<
ReduceSumOp
,
ReduceSumOp
>
;
using
DxsReduceOp
=
ck
::
Tuple
<
ReduceSumOp
,
ReduceSumOp
>
;
using
UnaryIdenticElementOp
=
using
UnaryIdenticElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
UnaryDivElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryDivide
;
using
UnaryDivElementOp
=
using
UnarySquareElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
true
>
;
using
DxsInElementOps
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
using
UnarySquareElementOp
=
using
DxsOutElementOps
=
ck
::
Tuple
<
UnaryDivElementOp
,
UnaryDivElementOp
>
;
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
DxsInElementOp
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
using
DxsOutElementOp
=
ck
::
Tuple
<
UnaryDivElementOp
,
UnaryDivElementOp
>
;
using
DxsGlobalMemOp
=
using
DxsGlobalMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
...
@@ -70,7 +65,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
...
@@ -70,7 +65,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
DxsReduceOp
,
DxsInElementOp
,
DxsOutElementOp
,
DxsGlobalMemOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
DxsReduceOp
,
DxsInElementOp
s
,
DxsOutElementOp
s
,
DxsGlobalMemOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
...
@@ -143,7 +138,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
...
@@ -143,7 +138,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
Tensor
<
CDataType
>
c_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
mean_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
Tensor
<
DDataType
>
mean_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
Tensor
<
DDataType
>
meanSquare_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
Tensor
<
DDataType
>
meanSquare_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
auto
averageOpInst
=
UnaryDivElementOp
{
M
};
auto
averageOpInst
=
UnaryDivElementOp
{
N
};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
@@ -157,13 +152,14 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
...
@@ -157,13 +152,14 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
auto
reduceSumOpInst
=
ReduceSumOp
{};
auto
reduceSumOpInst
=
ReduceSumOp
{};
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
{
float
mean_acc
=
reduceSumOpInst
.
GetIdentityValue
();
auto
mean_acc
=
reduceSumOpInst
.
GetIdentityValue
<
ReduceAccDataType
>
();
float
square_mean_acc
=
reduceSumOpInst
.
GetIdentityValue
();
auto
square_mean_acc
=
reduceSumOpInst
.
GetIdentityValue
<
ReduceAccDataType
>
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
ReduceAccDataType
c_val
=
ck
::
type_convert
<
float
>
(
c_m_n
(
m
,
n
));
auto
c_val
=
ck
::
type_convert
<
ReduceAccDataType
>
(
c_m_n
(
m
,
n
));
ReduceAccDataType
square_c_val
=
0
;
auto
square_c_val
=
reduceSumOpInst
.
GetIdentityValue
<
ReduceAccDataType
>
();
UnarySquareElementOp
{}(
square_c_val
,
c_val
);
UnarySquareElementOp
{}(
square_c_val
,
c_val
);
reduceSumOpInst
(
mean_acc
,
c_val
);
reduceSumOpInst
(
mean_acc
,
c_val
);
...
@@ -183,7 +179,12 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
...
@@ -183,7 +179,12 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
float
out_f32
=
0
;
float
out_f32
=
0
;
layerNormInst
(
out_f32
,
c_m_n
(
m
,
n
),
mean_m
(
m
),
meanSquare_m
(
m
),
gamma_n
(
n
),
beta_n
(
n
));
layerNormInst
(
out_f32
,
static_cast
<
float
>
(
c_m_n
(
m
,
n
)),
static_cast
<
float
>
(
mean_m
(
m
)),
static_cast
<
float
>
(
meanSquare_m
(
m
)),
static_cast
<
float
>
(
gamma_n
(
n
)),
static_cast
<
float
>
(
beta_n
(
n
)));
out_m_n
(
m
,
n
)
=
static_cast
<
out_type
>
(
out_f32
);
out_m_n
(
m
,
n
)
=
static_cast
<
out_type
>
(
out_f32
);
}
}
}
}
...
@@ -267,8 +268,8 @@ int main()
...
@@ -267,8 +268,8 @@ int main()
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
reduceMean_device_buf
.
GetDeviceBuffer
()),
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
reduceMean_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
reduceMeanSquare_device_buf
.
GetDeviceBuffer
()));
static_cast
<
DDataType
*>
(
reduceMeanSquare_device_buf
.
GetDeviceBuffer
()));
auto
dxs_in_element_op
=
DxsInElementOp
{};
auto
dxs_in_element_op
=
DxsInElementOp
s
{};
auto
dxs_out_element_op
=
DxsOutElementOp
{
M
,
M
};
auto
dxs_out_element_op
=
DxsOutElementOp
s
{
N
,
N
};
// Prepare GEMM, reduce_mean, reduce_mean_square
// Prepare GEMM, reduce_mean, reduce_mean_square
auto
gemmReduce
=
DeviceGemmReduceInstance
{};
auto
gemmReduce
=
DeviceGemmReduceInstance
{};
...
...
example/CMakeLists.txt
View file @
2c1ed8b2
...
@@ -39,7 +39,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
...
@@ -39,7 +39,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
add_subdirectory
(
01_gemm
)
add_subdirectory
(
01_gemm
)
add_subdirectory
(
02_gemm_alpha_beta
)
add_subdirectory
(
02_gemm_alpha_beta
)
add_subdirectory
(
03_gemm_bias_relu
)
add_subdirectory
(
03_gemm_bias_relu
)
add_subdirectory
(
04_gemm_
bias_relu_add
)
add_subdirectory
(
04_gemm_
add_add_fastgelu
)
add_subdirectory
(
06_conv2d_fwd_bias_relu
)
add_subdirectory
(
06_conv2d_fwd_bias_relu
)
add_subdirectory
(
07_conv2d_fwd_bias_relu_add
)
add_subdirectory
(
07_conv2d_fwd_bias_relu_add
)
add_subdirectory
(
09_convnd_fwd
)
add_subdirectory
(
09_convnd_fwd
)
...
...
include/ck/tensor_description/tensor_adaptor.hpp
View file @
2c1ed8b2
...
@@ -136,7 +136,11 @@ struct TensorAdaptor
...
@@ -136,7 +136,11 @@ struct TensorAdaptor
using
ElementSize
=
remove_cv_t
<
decltype
(
InitializeElementSize
(
Transforms
{}))
>
;
using
ElementSize
=
remove_cv_t
<
decltype
(
InitializeElementSize
(
Transforms
{}))
>
;
public:
public:
#if 0 // workaround compiler complaint about constexpr
__host__ __device__ constexpr TensorAdaptor() = default;
__host__ __device__ constexpr TensorAdaptor() = default;
#else
__host__
__device__
constexpr
TensorAdaptor
()
:
transforms_
{},
element_size_
{}
{}
#endif
__host__
__device__
constexpr
TensorAdaptor
(
const
Transforms
&
transforms
)
__host__
__device__
constexpr
TensorAdaptor
(
const
Transforms
&
transforms
)
:
transforms_
{
transforms
},
element_size_
{
InitializeElementSize
(
transforms
)}
:
transforms_
{
transforms
},
element_size_
{
InitializeElementSize
(
transforms
)}
...
...
include/ck/tensor_description/tensor_descriptor.hpp
View file @
2c1ed8b2
...
@@ -111,7 +111,14 @@ struct TensorDescriptor
...
@@ -111,7 +111,14 @@ struct TensorDescriptor
using
ElementSize
=
remove_cv_t
<
decltype
(
InitializeElementSize
(
Transforms
{}))
>
;
using
ElementSize
=
remove_cv_t
<
decltype
(
InitializeElementSize
(
Transforms
{}))
>
;
public:
public:
#if 0 // workaround compiler complaint about constexpr
__host__ __device__ constexpr TensorDescriptor() = default;
__host__ __device__ constexpr TensorDescriptor() = default;
#else
__host__
__device__
constexpr
TensorDescriptor
()
:
transforms_
{},
element_size_
{},
element_space_size_
{}
{
}
#endif
__host__
__device__
constexpr
TensorDescriptor
(
const
Transforms
&
transforms
,
__host__
__device__
constexpr
TensorDescriptor
(
const
Transforms
&
transforms
,
ElementSpaceSize
element_space_size
)
ElementSpaceSize
element_space_size
)
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
0 → 100644
View file @
2c1ed8b2
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v7.hpp"
namespace
ck
{
// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
// 1. Pass tensor descritpors by reference (or tuple of references)
// 2. Does not keep reference to tensor descriptor
// 3. Does not construct new tensor coordinate when call Run()
template
<
typename
ThreadGroup
,
typename
SrcDatas
,
typename
DstDatas
,
typename
SrcDescs
,
typename
DstDescs
,
typename
ElementwiseOperation
,
typename
DstInMemOps
,
// Sequence<InMemoryDataOperationEnum ...>
typename
SliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
DimAccessOrder
,
index_t
VectorDim
,
index_t
ScalarPerVector
,
typename
ThreadTransferSrcResetCoordinateAfterRunFlags
,
typename
ThreadTransferDstResetCoordinateAfterRunFlags
>
struct
ThreadGroupTensorSliceTransfer_v7
{
static
constexpr
index_t
nDim
=
remove_cvref_t
<
tuple_element_t
<
0
,
SrcDescs
>>::
GetNumOfDimension
();
static
constexpr
index_t
nSrc
=
remove_cvref_t
<
SrcDescs
>::
Size
();
static
constexpr
index_t
nDst
=
remove_cvref_t
<
DstDescs
>::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
static
constexpr
auto
thread_slice_lengths
=
SliceLengths
{}
/
ThreadClusterLengths
{};
__device__
constexpr
ThreadGroupTensorSliceTransfer_v7
(
const
SrcDescs
&
src_descs
,
const
StaticallyIndexedArray
<
Index
,
nSrc
>&
src_block_slice_origins
,
const
DstDescs
&
dst_descs
,
const
StaticallyIndexedArray
<
Index
,
nDst
>&
dst_block_slice_origins
,
const
ElementwiseOperation
&
element_op
)
:
threadwise_transfer_
(
src_descs
,
StaticallyIndexedArray
<
Index
,
nSrc
>
{},
dst_descs
,
StaticallyIndexedArray
<
Index
,
nDst
>
{},
element_op
)
{
static_assert
(
nSrc
==
SrcDatas
::
Size
()
&&
nSrc
==
SrcDescs
::
Size
()
&&
nSrc
==
ThreadTransferSrcResetCoordinateAfterRunFlags
::
Size
()
&&
nDst
==
DstDatas
::
Size
()
&&
nDst
==
DstDescs
::
Size
()
&&
nDst
==
ThreadTransferDstResetCoordinateAfterRunFlags
::
Size
(),
"wrong!"
);
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
static_assert
(
nDim
==
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
SrcDescs
>>::
GetNumOfDimension
(),
"wrong!"
);
});
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
static_assert
(
nDim
==
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DstDescs
>>::
GetNumOfDimension
(),
"wrong!"
);
});
static_assert
(
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
DimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
SliceLengths
,
decltype
(
thread_slice_lengths
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
ThreadGroup
::
GetNumOfThread
()
>=
thread_cluster_desc_
.
GetElementSize
(),
"wrong! ThreadGroup::GetNumOfThread() too small"
);
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
const
auto
src_thread_slice_origins
=
generate_tuple
(
[
&
](
auto
i
)
{
return
src_block_slice_origins
[
i
]
+
thread_data_idx_begin
;
},
Number
<
nSrc
>
{});
const
auto
dst_thread_slice_origins
=
generate_tuple
(
[
&
](
auto
i
)
{
return
dst_block_slice_origins
[
i
]
+
thread_data_idx_begin
;
},
Number
<
nDst
>
{});
threadwise_transfer_
.
SetSrcSliceOrigins
(
src_descs
,
src_thread_slice_origins
);
threadwise_transfer_
.
SetDstSliceOrigins
(
dst_descs
,
dst_thread_slice_origins
);
}
}
template
<
typename
SrcBuffers
,
typename
DstBuffers
>
__device__
void
Run
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
,
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
Run
(
src_descs
,
src_bufs
,
dst_descs
,
dst_bufs
);
}
}
template
<
index_t
ISrc
>
__device__
void
MoveSrcSliceWindow
(
const
SrcDescs
&
src_descs
,
Number
<
ISrc
>
iSrc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_descs
,
iSrc
,
step
);
}
}
template
<
index_t
IDst
>
__device__
void
MoveDstSliceWindow
(
const
DstDescs
&
dst_descs
,
Number
<
IDst
>
iDst
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveDstSliceWindow
(
dst_descs
,
iDst
,
step
);
}
}
private:
static
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v7
<
SrcDatas
,
DstDatas
,
SrcDescs
,
DstDescs
,
ElementwiseOperation
,
DstInMemOps
,
decltype
(
thread_slice_lengths
),
DimAccessOrder
,
VectorDim
,
ScalarPerVector
,
ThreadTransferSrcResetCoordinateAfterRunFlags
,
ThreadTransferDstResetCoordinateAfterRunFlags
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp
View file @
2c1ed8b2
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include <sstream>
#include <sstream>
#include "device.hpp"
#include "device.hpp"
#include "device_base.hpp"
#include "device_base.hpp"
#include "common_header.hpp"
#include "gridwise_5ary_Elementwise_1d.hpp"
#include "gridwise_5ary_Elementwise_1d.hpp"
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
...
...
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
2c1ed8b2
...
@@ -15,6 +15,8 @@ struct BaseArgument
...
@@ -15,6 +15,8 @@ struct BaseArgument
BaseArgument
&
operator
=
(
const
BaseArgument
&
)
=
default
;
BaseArgument
&
operator
=
(
const
BaseArgument
&
)
=
default
;
virtual
~
BaseArgument
()
{}
virtual
~
BaseArgument
()
{}
void
*
p_workspace_
=
nullptr
;
};
};
struct
BaseInvoker
struct
BaseInvoker
...
@@ -42,7 +44,11 @@ struct BaseOperator
...
@@ -42,7 +44,11 @@ struct BaseOperator
virtual
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
)
const
{
return
0
;
}
virtual
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
)
const
{
return
0
;
}
virtual
void
SetWorkSpacePointer
(
BaseArgument
*
,
void
*
)
const
{}
virtual
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
p_workspace
)
const
{
assert
(
p_arg
);
p_arg
->
p_workspace_
=
p_workspace
;
}
virtual
~
BaseOperator
()
{}
virtual
~
BaseOperator
()
{}
};
};
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
2c1ed8b2
...
@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
...
@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -44,7 +44,7 @@ __global__ void
...
@@ -44,7 +44,7 @@ __global__ void
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
DxsInElementwiseOperation
dxs_in_element_op
,
const
DxsInElementwiseOperation
dxs_in_element_op
,
const
DxsAccElementwiseOperation
dxs_out_element_op
,
const
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -126,7 +126,7 @@ template <typename ALayout,
...
@@ -126,7 +126,7 @@ template <typename ALayout,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsReduceOperation
,
typename
DxsReduceOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
typename
DGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
...
@@ -162,12 +162,12 @@ template <typename ALayout,
...
@@ -162,12 +162,12 @@ template <typename ALayout,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceBatchedGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
DPtrsGlobal
,
struct
DeviceBatchedGemmReduce_Xdl_CShuffle
AElementwiseOperation
,
:
public
DeviceGemmReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>
Dxs
Reduce
AccElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceBatchedGemmReduce_Xdl_CShuffle
;
using
DeviceOp
=
DeviceBatchedGemmReduce_Xdl_CShuffle
;
...
@@ -527,7 +527,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -527,7 +527,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
CElementwiseOperation
,
CElementwiseOperation
,
DxsReduceOperation
,
DxsReduceOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
DGlobalMemoryDataOperation
,
DGlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
...
@@ -587,7 +587,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -587,7 +587,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
index_t
BatchCount
)
index_t
BatchCount
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
...
@@ -645,7 +645,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -645,7 +645,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
DxsInElementwiseOperation
dxs_in_element_op_
;
DxsInElementwiseOperation
dxs_in_element_op_
;
DxsAccElementwiseOperation
dxs_out_element_op_
;
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op_
;
};
};
// Invoker
// Invoker
...
@@ -703,7 +703,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -703,7 +703,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -746,7 +746,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -746,7 +746,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -832,7 +832,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -832,7 +832,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
index_t
BatchCount
)
index_t
BatchCount
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
...
@@ -856,27 +856,29 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
...
@@ -856,27 +856,29 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
const
void
*
p_b
,
MakeArgumentPointer
(
const
void
*
p_a
,
void
*
p_c
,
const
void
*
p_b
,
DPtrsGlobal
p_dxs
,
void
*
p_c
,
index_t
MRaw
,
void
*
p_dxs
,
index_t
NRaw
,
index_t
MRaw
,
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideA
,
index_t
KRaw
,
index_t
StrideB
,
index_t
StrideA
,
index_t
StrideC
,
index_t
StrideB
,
AElementwiseOperation
a_element_op
,
index_t
StrideC
,
BElementwiseOperation
b_element_op
,
AElementwiseOperation
a_element_op
,
CElementwiseOperation
c_element_op
,
BElementwiseOperation
b_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
CElementwiseOperation
c_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
index_t
BatchCount
)
override
DxsReduceAccElementwiseOperation
dxs_out_element_op
,
index_t
BatchCount
)
override
{
{
DPtrsGlobal
dxs_tuple
=
*
(
static_cast
<
DPtrsGlobal
*>
(
p_dxs
));
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
CDataType
*>
(
p_c
),
p_
dxs
,
dxs
_tuple
,
MRaw
,
MRaw
,
NRaw
,
NRaw
,
KRaw
,
KRaw
,
...
...
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
2c1ed8b2
...
@@ -557,11 +557,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -557,11 +557,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
float
ave_time
=
0
;
float
ave_time
=
0
;
using
Add
=
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
ck
::
tensor_operation
::
binary_element_wise
::
Add
<
CDataType
,
CDataType
,
CDataType
>
;
using
Subtract
=
ck
::
tensor_operation
::
element_wise
::
Subtract
;
using
Substract
=
ck
::
tensor_operation
::
binary_element_wise
::
using
GridwiseBinAdd
=
GridwiseBinaryElementwise_1D
<
CDataType
,
Substract
<
CDataType
,
CDataType
,
CDataType
>
;
using
GridwiseBinAdd
=
GridwiseBinaryElementwise_1D
<
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
...
@@ -573,19 +571,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -573,19 +571,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AScalarPerVector
,
AScalarPerVector
,
BScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
CScalarPerVector
>
;
using
GridwiseBinSub
s
tract
=
GridwiseBinaryElementwise_1D
<
CDataType
,
using
GridwiseBinSubtract
=
GridwiseBinaryElementwise_1D
<
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Sub
s
tract
,
Subtract
,
MPerThread
,
MPerThread
,
AScalarPerVector
,
AScalarPerVector
,
BScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
CScalarPerVector
>
;
const
auto
add_kernel
=
kernel_binary_elementwise_1d
<
GridwiseBinAdd
,
const
auto
add_kernel
=
kernel_binary_elementwise_1d
<
GridwiseBinAdd
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
...
@@ -593,14 +591,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -593,14 +591,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Add
>
;
Add
>
;
const
auto
sub
s
tract_kernel
=
kernel_binary_elementwise_1d
<
GridwiseBinSub
s
tract
,
const
auto
subtract_kernel
=
kernel_binary_elementwise_1d
<
GridwiseBinSubtract
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Sub
s
tract
>
;
Subtract
>
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
...
@@ -653,7 +651,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -653,7 +651,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// c_real = aux - aux_2
// c_real = aux - aux_2
ave_time
+=
launch_and_time_kernel
(
stream_config
,
ave_time
+=
launch_and_time_kernel
(
stream_config
,
sub
s
tract_kernel
,
subtract_kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
...
@@ -663,7 +661,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -663,7 +661,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
Sub
s
tract
{});
Subtract
{});
ave_time
+=
ave_time
+=
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
...
@@ -764,7 +762,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -764,7 +762,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// c_real = aux - aux_2
// c_real = aux - aux_2
ave_time
+=
launch_and_time_kernel
(
stream_config
,
ave_time
+=
launch_and_time_kernel
(
stream_config
,
sub
s
tract_kernel
,
subtract_kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
...
@@ -774,7 +772,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -774,7 +772,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
Sub
s
tract
{});
Subtract
{});
ave_time
+=
ave_time
+=
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
2c1ed8b2
...
@@ -460,6 +460,8 @@ struct
...
@@ -460,6 +460,8 @@ struct
using
C0GridDesc_M_N
=
remove_cvref_t
<
decltype
(
GridDescs
{}[
I3
])
>
;
using
C0GridDesc_M_N
=
remove_cvref_t
<
decltype
(
GridDescs
{}[
I3
])
>
;
using
C1GridDesc_M_N
=
remove_cvref_t
<
decltype
(
GridDescs
{}[
I4
])
>
;
using
C1GridDesc_M_N
=
remove_cvref_t
<
decltype
(
GridDescs
{}[
I4
])
>
;
using
Block2CTileMap
=
BlockToCTileMap_M00_N0_M01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
;
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
<
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
<
BlockSize
,
BlockSize
,
...
@@ -522,8 +524,6 @@ struct
...
@@ -522,8 +524,6 @@ struct
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
OutElementwiseOperation
out_element_op
)
...
@@ -540,10 +540,7 @@ struct
...
@@ -540,10 +540,7 @@ struct
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
{},
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
{},
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
{},
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
{},
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
{},
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
{},
block_2_ctile_map_
{
block_2_ctile_map_
{},
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
)},
M01_
{
M01
},
N01_
{
N01
},
in_element_op_
{
in_element_op
},
in_element_op_
{
in_element_op
},
wei_element_op_
{
wei_element_op
},
wei_element_op_
{
wei_element_op
},
out_element_op_
{
out_element_op
},
out_element_op_
{
out_element_op
},
...
@@ -576,6 +573,8 @@ struct
...
@@ -576,6 +573,8 @@ struct
c0_grid_desc_m_n_
=
descs
[
I3
];
c0_grid_desc_m_n_
=
descs
[
I3
];
c1_grid_desc_m_n_
=
descs
[
I4
];
c1_grid_desc_m_n_
=
descs
[
I4
];
block_2_ctile_map_
=
Block2CTileMap
{
c_grid_desc_m_n_
};
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
c_grid_desc_m_n_
,
...
@@ -618,9 +617,7 @@ struct
...
@@ -618,9 +617,7 @@ struct
typename
GridwiseGemm
::
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
;
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
InElementwiseOperation
in_element_op_
;
InElementwiseOperation
in_element_op_
;
WeiElementwiseOperation
wei_element_op_
;
WeiElementwiseOperation
wei_element_op_
;
OutElementwiseOperation
out_element_op_
;
OutElementwiseOperation
out_element_op_
;
...
@@ -723,7 +720,7 @@ struct
...
@@ -723,7 +720,7 @@ struct
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
Default
Block2CTileMap
>
,
Block2CTileMap
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
...
@@ -767,7 +764,7 @@ struct
...
@@ -767,7 +764,7 @@ struct
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
Default
Block2CTileMap
>
,
Block2CTileMap
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
...
@@ -894,8 +891,6 @@ struct
...
@@ -894,8 +891,6 @@ struct
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
1
,
1
,
in_element_op
,
in_element_op
,
wei_element_op
,
wei_element_op
,
out_element_op
};
out_element_op
};
...
@@ -938,8 +933,6 @@ struct
...
@@ -938,8 +933,6 @@ struct
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
1
,
1
,
in_element_op
,
in_element_op
,
wei_element_op
,
wei_element_op
,
out_element_op
);
out_element_op
);
...
...
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
2c1ed8b2
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
#include "gridwise_unary_elementwise_1d.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
1
);
1
);
}
}
// type convert descs
template
<
typename
Desc_M0
>
static
auto
PadDescriptor_M0_1d
(
Desc_M0
desc_m0
,
index_t
gridSize
,
index_t
blockSize
)
{
const
auto
m0
=
desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
4
;
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
loop_step
)
-
m0
;
const
auto
desc_m0_pad
=
transform_tensor_descriptor
(
desc_m0
,
make_tuple
(
make_right_pad_transform
(
m0
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m0_pad
;
}
template
<
index_t
Dim
>
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
index_t
>&
shape
,
const
std
::
vector
<
index_t
>&
stride
,
index_t
gridSize
,
index_t
blockSize
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
shape
[
I
];
},
Number
<
Dim
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
Dim
>
{});
// nd desc - [s0, s1, s2, ...]
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
// merge nd to 1d desc - [s0 * s1 * ...]
if
constexpr
(
Dim
>
1
)
{
const
auto
desc_m0
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleOfShape
)),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
Dim
>
{})),
make_tuple
(
Sequence
<
0
>
{}));
return
PadDescriptor_M0_1d
(
desc_m0
,
gridSize
,
blockSize
);
}
else
return
PadDescriptor_M0_1d
(
desc
,
gridSize
,
blockSize
);
}
using
TypeConvertFunctor
=
ck
::
tensor_operation
::
element_wise
::
UnaryTypeConvert
<
ck
::
bhalf_t
,
float
>
;
using
GridDesc_M0
=
decltype
(
MakeDescriptor_M0
<
1
>
({
1
},
{
1
},
1
,
1
));
using
GridwiseUEltwise
=
GridwiseUnaryElementwise_1D
<
AccDataType
,
InDataType
,
GridDesc_M0
,
TypeConvertFunctor
,
4
>
;
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NumDimSpatial
>
());
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NumDimSpatial
>
());
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
...
@@ -733,6 +782,55 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -733,6 +782,55 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
true
,
true
,
true
>
;
true
>
;
using
GridwiseGemmAtomicAddFloatBf16Splitk
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
InMemoryDataOperationEnum
::
AtomicAdd
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerXdl
,
NPerXdl
,
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
ABlockLdsM1PerBlock
,
ABlockLdsM0PerBlock
,
ABlockLdsM1Padding
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
BBlockLdsN1PerBlock
,
BBlockLdsN0PerBlock
,
BBlockLdsN1Padding
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
true
,
true
>
;
// Argument
// Argument
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
...
@@ -910,41 +1008,159 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -910,41 +1008,159 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
};
};
// run kernel for bf16 with splitk
const
auto
run_bf16_splitk
=
[
&
](
const
auto
&
kernel
)
{
hipGetErrorString
(
hipMemset
(
arg
.
p_workspace_
,
0
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
sizeof
(
AccDataType
)));
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
static_cast
<
AccDataType
*>
(
arg
.
p_workspace_
),
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
};
// kernel for type conversion
std
::
vector
<
std
::
size_t
>
filter_dims
{
static_cast
<
std
::
size_t
>
(
arg
.
Conv_K_
),
static_cast
<
std
::
size_t
>
(
arg
.
Conv_C_
)};
filter_dims
.
insert
(
std
::
end
(
filter_dims
),
std
::
begin
(
arg
.
filter_spatial_lengths_
),
std
::
end
(
arg
.
filter_spatial_lengths_
));
int
tensor_size
=
std
::
accumulate
(
filter_dims
.
begin
(),
filter_dims
.
end
(),
1
,
std
::
multiplies
<
int
>
{});
const
index_t
type_convert_grid_size
=
GridwiseUEltwise
::
CalculateGridSize
(
tensor_size
);
GridDesc_M0
a_grid_desc_m0_
=
MakeDescriptor_M0
<
1
>
({
tensor_size
},
{
1
},
type_convert_grid_size
,
256
);
GridDesc_M0
b_grid_desc_m0_
=
MakeDescriptor_M0
<
1
>
({
tensor_size
},
{
1
},
type_convert_grid_size
,
256
);
if
(
!
GridwiseUEltwise
::
CheckValidity
(
a_grid_desc_m0_
,
b_grid_desc_m0_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseUnaryElementwise_1D has invalid setting"
);
}
// run kernel for type conversion
void
*
p_c_grid_tmp_
=
static_cast
<
void
*>
(
arg
.
p_c_grid_
);
InDataType
*
p_c_grid_tmp_bf16_
=
static_cast
<
InDataType
*>
(
p_c_grid_tmp_
);
const
auto
Run_type_convert
=
[
&
](
const
auto
&
kernel
)
{
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
type_convert_grid_size
),
dim3
(
256
),
0
,
static_cast
<
AccDataType
*>
(
arg
.
p_workspace_
),
p_c_grid_tmp_bf16_
,
a_grid_desc_m0_
,
b_grid_desc_m0_
,
TypeConvertFunctor
{});
return
elapsed_time
;
};
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
{
{
if
(
has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
if
(
kbatch
==
1
)
GridwiseGemm
,
{
ADataType
,
// TODO: distiguish A/B datatype
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
CDataType
,
GridwiseGemm
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
ADataType
,
// TODO: distiguish A/B datatype
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
CDataType
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
InElementwiseOperation
,
remove_reference_t
<
WeiElementwiseOperation
,
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
OutElementwiseOperation
,
true
>
;
InElementwiseOperation
,
WeiElementwiseOperation
,
Run
(
kernel
);
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel_type_convert
=
kernel_unary_elementwise_1d
<
GridwiseUEltwise
,
AccDataType
,
InDataType
,
GridDesc_M0
,
TypeConvertFunctor
>
;
const
auto
kernel_conv
=
kernel_gemm_xdlops_bwd_weight
<
GridwiseGemmAtomicAddFloatBf16Splitk
,
ADataType
,
// TODO: distiguish A/B datatype
AccDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
run_bf16_splitk
(
kernel_conv
);
ave_time
+=
Run_type_convert
(
kernel_type_convert
);
}
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
if
(
kbatch
==
1
)
GridwiseGemm
,
{
ADataType
,
// TODO: distiguish A/B datatype
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
CDataType
,
GridwiseGemm
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
ADataType
,
// TODO: distiguish A/B datatype
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
CDataType
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
InElementwiseOperation
,
remove_reference_t
<
WeiElementwiseOperation
,
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
OutElementwiseOperation
,
false
>
;
InElementwiseOperation
,
WeiElementwiseOperation
,
Run
(
kernel
);
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
GridwiseGemmAtomicAddFloatBf16Splitk
,
ADataType
,
// TODO: distiguish A/B datatype
AccDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
run_bf16_splitk
(
kernel
);
}
}
}
}
}
else
else
...
...
include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
0 → 100644
View file @
2c1ed8b2
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
0 → 100644
View file @
2c1ed8b2
#pragma once
#include <array>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template
<
ck
::
index_t
NumDTensor
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGemmMultipleD
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
ck
::
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
ck
::
index_t
NumDTensor
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmMultipleDPtr
=
std
::
unique_ptr
<
DeviceGemmMultipleD
<
NumDTensor
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
0 → 100644
View file @
2c1ed8b2
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
View file @
2c1ed8b2
...
@@ -6,19 +6,18 @@ namespace ck {
...
@@ -6,19 +6,18 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
template
<
typename
DPtrsGlobal
,
template
<
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
>
typename
Dxs
Reduce
AccElementwiseOperation
>
struct
DeviceGemmReduce
:
public
BaseOperator
struct
DeviceGemmReduce
:
public
BaseOperator
{
{
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
void
*
p_c
,
void
*
p_c
,
DPtrsGlobal
p_dxs
,
void
*
p_dxs
,
ck
::
index_t
M
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
...
@@ -29,24 +28,69 @@ struct DeviceGemmReduce : public BaseOperator
...
@@ -29,24 +28,69 @@ struct DeviceGemmReduce : public BaseOperator
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
ck
::
index_t
BatchCount
=
1
)
=
0
;
ck
::
index_t
BatchCount
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
template
<
typename
DPtrsGlobal
,
template
<
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
>
typename
DxsReduceAccElementwiseOperation
>
using
DeviceGemmReducePtr
=
std
::
unique_ptr
<
DeviceGemmReduce
<
DPtrsGlobal
,
using
DeviceGemmReducePtr
=
std
::
unique_ptr
<
DeviceGemmReduce
<
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>>
;
DxsReduceAccElementwiseOperation
>>
;
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsReduceAccElementwiseOperation
>
struct
DeviceGemmBiasAddReduce
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
const
void
*
p_c0
,
const
void
*
p_c1
,
void
*
p_dxs
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
StrideC1
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C1ElementwiseOperation
c1_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsReduceAccElementwiseOperation
dxs_out_element_op
,
ck
::
index_t
BatchCount
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsReduceAccElementwiseOperation
>
using
DeviceGemmBiasAddReducePtr
=
std
::
unique_ptr
<
DeviceGemmBiasAddReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C1ElementwiseOperation
,
DxsInElementwiseOperation
,
DxsReduceAccElementwiseOperation
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
View file @
2c1ed8b2
...
@@ -32,7 +32,7 @@ template <typename ALayout,
...
@@ -32,7 +32,7 @@ template <typename ALayout,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsReduceOperation
,
typename
DxsReduceOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
typename
DGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
...
@@ -68,12 +68,11 @@ template <typename ALayout,
...
@@ -68,12 +68,11 @@ template <typename ALayout,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
index_t
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
DPtrsGlobal
,
struct
DeviceGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
>
Dxs
Reduce
AccElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceGemmReduce_Xdl_CShuffle
;
using
DeviceOp
=
DeviceGemmReduce_Xdl_CShuffle
;
...
@@ -389,7 +388,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -389,7 +388,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
CElementwiseOperation
,
CElementwiseOperation
,
DxsReduceOperation
,
DxsReduceOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
DGlobalMemoryDataOperation
,
DGlobalMemoryDataOperation
,
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
...
@@ -449,7 +448,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -449,7 +448,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
)
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
...
@@ -498,7 +497,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -498,7 +497,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
DxsInElementwiseOperation
dxs_in_element_op_
;
DxsInElementwiseOperation
dxs_in_element_op_
;
DxsAccElementwiseOperation
dxs_out_element_op_
;
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op_
;
};
};
// Invoker
// Invoker
...
@@ -554,7 +553,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -554,7 +553,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -594,7 +593,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -594,7 +593,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DxsInElementwiseOperation
,
DxsInElementwiseOperation
,
DxsAccElementwiseOperation
,
Dxs
Reduce
AccElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -669,7 +668,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -669,7 +668,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
)
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_b
,
...
@@ -691,27 +690,29 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
...
@@ -691,27 +690,29 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
const
void
*
p_b
,
MakeArgumentPointer
(
const
void
*
p_a
,
void
*
p_c
,
const
void
*
p_b
,
DPtrsGlobal
p_dxs
,
void
*
p_c
,
index_t
MRaw
,
void
*
p_dxs
,
index_t
NRaw
,
index_t
MRaw
,
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideA
,
index_t
KRaw
,
index_t
StrideB
,
index_t
StrideA
,
index_t
StrideC
,
index_t
StrideB
,
AElementwiseOperation
a_element_op
,
index_t
StrideC
,
BElementwiseOperation
b_element_op
,
AElementwiseOperation
a_element_op
,
CElementwiseOperation
c_element_op
,
BElementwiseOperation
b_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
CElementwiseOperation
c_element_op
,
DxsAccElementwiseOperation
dxs_out_element_op
,
DxsInElementwiseOperation
dxs_in_element_op
,
index_t
/* KBatch */
=
1
)
override
DxsReduceAccElementwiseOperation
dxs_out_element_op
,
index_t
/* KBatch */
=
1
)
override
{
{
DPtrsGlobal
dxs_tuple
=
*
(
static_cast
<
DPtrsGlobal
*>
(
p_dxs
));
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
CDataType
*>
(
p_c
),
p_
dxs
,
dxs
_tuple
,
MRaw
,
MRaw
,
NRaw
,
NRaw
,
KRaw
,
KRaw
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
2c1ed8b2
...
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl
...
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl
{
{
grid_size_
=
0
;
grid_size_
=
0
;
gemm_descs_args
_workspace_
=
nullptr
;
p
_workspace_
=
nullptr
;
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_shapes
.
size
());
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_shapes
.
size
());
...
@@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl
...
@@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl
std
::
vector
<
GemmDescKernelArg
>
gemm_desc_kernel_arg_
;
std
::
vector
<
GemmDescKernelArg
>
gemm_desc_kernel_arg_
;
void
*
gemm_descs_args_workspace_
;
index_t
grid_size_
;
index_t
grid_size_
;
};
};
...
@@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl
...
@@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl
}
}
hipGetErrorString
(
hipGetErrorString
(
hipMemcpy
(
arg
.
gemm_descs_args
_workspace_
,
hipMemcpy
(
arg
.
p
_workspace_
,
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
),
hipMemcpyHostToDevice
));
hipMemcpyHostToDevice
));
...
@@ -507,17 +505,17 @@ struct DeviceGroupedGemmXdl
...
@@ -507,17 +505,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation
,
CElementwiseOperation
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
stream_config
,
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
arg
.
grid_size_
),
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
cast_pointer_to_constant_address_space
(
arg
.
gemm_descs_args
_workspace_
),
cast_pointer_to_constant_address_space
(
arg
.
p
_workspace_
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
arg
.
c_element_op_
);
}
}
else
else
{
{
...
@@ -531,17 +529,17 @@ struct DeviceGroupedGemmXdl
...
@@ -531,17 +529,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation
,
CElementwiseOperation
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
stream_config
,
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
arg
.
grid_size_
),
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
cast_pointer_to_constant_address_space
(
arg
.
gemm_descs_args
_workspace_
),
cast_pointer_to_constant_address_space
(
arg
.
p
_workspace_
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
arg
.
c_element_op_
);
}
}
return
ave_time
;
return
ave_time
;
...
@@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl
...
@@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl
{
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GemmDescKernelArg
);
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GemmDescKernelArg
);
}
}
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
workspace_ptr
)
const
override
{
dynamic_cast
<
Argument
*>
(
p_arg
)
->
gemm_descs_args_workspace_
=
workspace_ptr
;
}
};
};
}
// namespace device
}
// namespace device
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment