Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8dd7156d
Commit
8dd7156d
authored
Jul 25, 2023
by
ltqin
Browse files
Merge branch 'mha-train-develop' into attn-train-develop-qloop-mask
parents
d5f629e7
b5a3ea2d
Changes
534
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1746 additions
and
324 deletions
+1746
-324
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+0
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+0
-6
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
..._operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
+316
-0
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp
...ion/gpu/device/impl/device_multiple_reduce_multiblock.hpp
+9
-9
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp
...ion/gpu/device/impl/device_multiple_reduce_threadwise.hpp
+7
-7
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
...r_operation/gpu/device/impl/device_normalization_impl.hpp
+67
-141
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
...tion/gpu/device/impl/device_normalization_splitk_impl.hpp
+658
-0
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp
.../tensor_operation/gpu/device/impl/device_permute_impl.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
...operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
+68
-45
include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp
...eration/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp
+357
-0
include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp
...sor_operation/gpu/device/impl/device_put_element_impl.hpp
+155
-0
include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp
...tensor_operation/gpu/device/impl/device_reduce_common.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp
...or_operation/gpu/device/impl/device_reduce_multiblock.hpp
+17
-9
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp
...or_operation/gpu/device/impl/device_reduce_threadwise.hpp
+18
-7
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp
.../tensor_operation/gpu/device/impl/device_softmax_impl.hpp
+18
-24
include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp
...evice/impl/device_sparse_embeddings_forward_layernorm.hpp
+45
-62
include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
...mpl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
+6
-3
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
...ck/tensor_operation/gpu/device/masking_specialization.hpp
+1
-1
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
+1
-1
include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp
...ensor_operation/gpu/device/reduction_operator_mapping.hpp
+1
-1
No files found.
Too many changes to show.
To preserve performance only
534 of 534+
files are displayed.
Plain diff
Email patch
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
8dd7156d
...
...
@@ -296,12 +296,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
Q_K1
=
8
;
static
constexpr
index_t
K_K1
=
8
;
static
constexpr
index_t
V_N1
=
2
;
static
constexpr
index_t
Q_M1
=
2
;
static
constexpr
index_t
K_N1
=
2
;
static
constexpr
index_t
V_O1
=
8
;
static
constexpr
index_t
Y_O1
=
8
;
static
constexpr
index_t
Y_M1
=
2
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
8dd7156d
...
...
@@ -303,12 +303,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
Q_K1
=
8
;
static
constexpr
index_t
K_K1
=
8
;
static
constexpr
index_t
V_N1
=
2
;
static
constexpr
index_t
Q_M1
=
2
;
static
constexpr
index_t
K_N1
=
2
;
static
constexpr
index_t
V_O1
=
8
;
static
constexpr
index_t
Y_O1
=
8
;
static
constexpr
index_t
Y_M1
=
2
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
0 → 100644
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// output[indices] = input
template
<
typename
DOutDataType
,
typename
IndexDataType
,
typename
DInDataType
,
ck
::
index_t
InOutVectorSize
>
struct
DeviceIndexPoolBwdImpl
:
public
DeviceIndexPoolBwd
<
DOutDataType
,
IndexDataType
,
DInDataType
>
{
using
DInDataType_AutomicAddPreCast
=
conditional_t
<
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
,
DInDataType
,
float
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
UnaryConvert
=
ck
::
tensor_operation
::
element_wise
::
UnaryConvert
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
template
<
typename
Desc_M
>
static
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
,
index_t
loop_step
)
{
const
auto
m
=
desc_m
.
GetLength
(
I0
);
const
auto
pad
=
math
::
integer_least_multiple
(
m
,
loop_step
)
-
m
;
const
auto
desc_m_pad
=
transform_tensor_descriptor
(
desc_m
,
make_tuple
(
make_right_pad_transform
(
m
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m_pad
;
}
static
auto
MakeDescriptor_M
(
index_t
length
,
index_t
loop_step
)
{
const
auto
desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
length
));
return
PadDescriptor_M_1d
(
desc_m
,
loop_step
);
}
using
InOutGrid1dDesc
=
decltype
(
MakeDescriptor_M
(
1
,
1
));
using
GridwisePutElementSet
=
GridwisePutElement_1D
<
InOutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
InOutVectorSize
>
;
using
GridwisePutElementAtomicAdd
=
GridwisePutElement_1D
<
InOutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType_AutomicAddPreCast
,
PassThrough
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InOutVectorSize
>
;
using
GridwiseCasting
=
GridwiseElementwise_1D
<
Tuple
<
InOutGrid1dDesc
>
,
Tuple
<
InOutGrid1dDesc
>
,
Tuple
<
const
DInDataType_AutomicAddPreCast
*>
,
Tuple
<
DInDataType
*>
,
UnaryConvert
,
InOutVectorSize
,
Sequence
<
InOutVectorSize
>
,
Sequence
<
InOutVectorSize
>>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
DOutDataType
*
p_dout
,
const
IndexDataType
*
p_indices
,
DInDataType
*
p_din
,
index_t
dout_length
,
index_t
din_length
,
const
std
::
vector
<
ck
::
index_t
>&
window_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
window_strides
)
:
p_dout_
{
p_dout
},
p_indices_
{
p_indices
},
p_din_
{
p_din
},
dout_length_raw_
{
dout_length
},
din_length_raw_
{
din_length
},
blockSize_
{
256
},
windowOverlap_
{
false
}
{
for
(
size_t
i
=
0
;
i
<
window_lengths
.
size
();
++
i
)
{
windowOverlap_
|=
window_lengths
.
at
(
i
)
>
window_strides
.
at
(
i
);
}
}
const
DOutDataType
*
p_dout_
;
const
IndexDataType
*
p_indices_
;
DInDataType
*
p_din_
;
index_t
dout_length_raw_
;
index_t
din_length_raw_
;
index_t
blockSize_
;
bool
windowOverlap_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
index_t
gridSize
=
getAvailableComputeUnitCount
(
stream_config
);
index_t
loop_step
=
gridSize
*
arg
.
blockSize_
*
InOutVectorSize
;
InOutGrid1dDesc
din_grid_desc
=
MakeDescriptor_M
(
arg
.
din_length_raw_
,
loop_step
);
InOutGrid1dDesc
dout_grid_desc
=
MakeDescriptor_M
(
arg
.
dout_length_raw_
,
loop_step
);
if
constexpr
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
)
{
hip_check_error
(
hipMemsetAsync
(
arg
.
p_din_
,
0
,
arg
.
din_length_raw_
*
sizeof
(
DInDataType
),
stream_config
.
stream_id_
));
if
(
arg
.
windowOverlap_
)
{
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
InOutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
>
;
return
launch_and_time_kernel
(
stream_config
,
put_kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
0
,
dout_grid_desc
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_din_
,
PassThrough
{});
}
else
{
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementSet
,
InOutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
>
;
return
launch_and_time_kernel
(
stream_config
,
put_kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
0
,
dout_grid_desc
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_din_
,
PassThrough
{});
}
}
else
{
if
(
arg
.
windowOverlap_
)
{
if
(
arg
.
p_workspace_
==
nullptr
)
throw
std
::
runtime_error
(
"wrong! WorkSpace pointer has not been set"
);
hip_check_error
(
hipMemsetAsync
(
arg
.
p_workspace_
,
0
,
arg
.
din_length_raw_
*
sizeof
(
DInDataType_AutomicAddPreCast
),
stream_config
.
stream_id_
));
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
InOutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType_AutomicAddPreCast
,
PassThrough
>
;
const
auto
cast_kernel
=
kernel_elementwise_1d
<
GridwiseCasting
,
Tuple
<
InOutGrid1dDesc
>
,
Tuple
<
InOutGrid1dDesc
>
,
Tuple
<
const
DInDataType_AutomicAddPreCast
*>
,
Tuple
<
DInDataType
*>
,
UnaryConvert
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
put_kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
0
,
dout_grid_desc
,
arg
.
p_dout_
,
arg
.
p_indices_
,
static_cast
<
DInDataType_AutomicAddPreCast
*>
(
arg
.
p_workspace_
),
PassThrough
{});
elapsed_time
+=
launch_and_time_kernel
(
stream_config
,
cast_kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
0
,
ck
::
make_tuple
(
din_grid_desc
),
ck
::
make_tuple
(
din_grid_desc
),
static_cast
<
DInDataType_AutomicAddPreCast
*>
(
arg
.
p_workspace_
),
arg
.
p_din_
,
UnaryConvert
{});
return
elapsed_time
;
}
else
{
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementSet
,
InOutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
>
;
hip_check_error
(
hipMemsetAsync
(
arg
.
p_din_
,
0
,
arg
.
din_length_raw_
*
sizeof
(
DInDataType
),
stream_config
.
stream_id_
));
return
launch_and_time_kernel
(
stream_config
,
put_kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
0
,
dout_grid_desc
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_din_
,
PassThrough
{});
}
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
bool
needCast
=
pArg_
->
windowOverlap_
&&
!
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
);
if
(
!
needCast
)
return
0
;
else
return
pArg_
->
din_length_raw_
*
sizeof
(
DInDataType_AutomicAddPreCast
);
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
pArg
->
din_length_raw_
%
InOutVectorSize
!=
0
||
pArg
->
dout_length_raw_
%
InOutVectorSize
!=
0
)
{
return
false
;
}
return
true
;
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_dout
,
const
void
*
p_indices
,
void
*
p_din
,
index_t
dout_length
,
index_t
din_length
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
)
override
{
// Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are
// physical size of the packed tensor
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DOutDataType
*>
(
p_dout
),
static_cast
<
const
IndexDataType
*>
(
p_indices
),
static_cast
<
DInDataType
*>
(
p_din
),
dout_length
,
din_length
,
window_lengths
,
window_strides
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -73,7 +73,7 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
static_for
<
0
,
NumReduction
,
1
>
{}([
&
](
auto
I
)
{
using
OutDataType
=
remove_cvref_t
<
decltype
(
OutDataTypeTuple
{}[
I
])
>
;
flag
=
flag
&&
ck
::
reduce
::
InMemoryDataOperatonSupportedOnDataType
<
OutMemoryDataOperation
,
flag
&&
ck
::
reduce
::
InMemoryDataOperat
i
onSupportedOnDataType
<
OutMemoryDataOperation
,
OutDataType
>::
value
;
});
...
...
@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
const
std
::
array
<
index_t
,
NumOutputDim
>&
outLengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumOutputDim
>
,
NumReduction
>&
outStridesArray
,
const
std
::
array
<
int
,
NumReduceDim
>&
reduceDims
,
const
std
::
array
<
const
void
*
,
NumReduction
>&
alphas
,
const
std
::
array
<
const
void
*
,
NumReduction
>&
betas
,
const
std
::
array
<
double
,
NumReduction
>&
alphas
,
const
std
::
array
<
double
,
NumReduction
>&
betas
,
const
void
*
in_dev
,
const
std
::
array
<
void
*
,
NumReduction
>&
out_dev_buffers
,
const
InElementwiseOperationTuple
in_elementwise_op_tuple
,
...
...
@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
for
(
size_t
i
=
0
;
i
<
NumReduction
;
i
++
)
{
alpha_values_
(
i
)
=
*
static_cast
<
const
AccDataType
*
>
(
alphas
[
i
]);
beta_values_
(
i
)
=
*
static_cast
<
const
AccDataType
*
>
(
betas
[
i
]);
alpha_values_
(
i
)
=
static_cast
<
AccDataType
>
(
alphas
[
i
]);
beta_values_
(
i
)
=
static_cast
<
AccDataType
>
(
betas
[
i
]);
};
in_dev_
=
static_cast
<
const
InDataType
*>
(
in_dev
);
...
...
@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
const
std
::
array
<
index_t
,
NumOutputDim
>
outLengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumOutputDim
>
,
NumReduction
>
outStridesArray
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
const
std
::
array
<
const
void
*
,
NumReduction
>
alphas
,
const
std
::
array
<
const
void
*
,
NumReduction
>
betas
,
const
std
::
array
<
double
,
NumReduction
>
alphas
,
const
std
::
array
<
double
,
NumReduction
>
betas
,
const
void
*
in_dev
,
const
std
::
array
<
void
*
,
NumReduction
>
out_dev_buffers
,
const
InElementwiseOperationTuple
in_elementwise_op_tuple
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
const
std
::
array
<
index_t
,
NumOutputDim
>&
outLengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumOutputDim
>
,
NumReduction
>&
outStridesArray
,
const
std
::
array
<
int
,
NumReduceDim
>&
reduceDims
,
const
std
::
array
<
const
void
*
,
NumReduction
>&
alphas
,
const
std
::
array
<
const
void
*
,
NumReduction
>&
betas
,
const
std
::
array
<
double
,
NumReduction
>&
alphas
,
const
std
::
array
<
double
,
NumReduction
>&
betas
,
const
void
*
in_dev
,
const
std
::
array
<
void
*
,
NumReduction
>&
out_dev_buffers
,
const
InElementwiseOperationTuple
in_elementwise_op_tuple
,
...
...
@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
for
(
size_t
i
=
0
;
i
<
NumReduction
;
i
++
)
{
alpha_values_
(
i
)
=
*
static_cast
<
const
AccDataType
*
>
(
alphas
[
i
]);
beta_values_
(
i
)
=
*
static_cast
<
const
AccDataType
*
>
(
betas
[
i
]);
alpha_values_
(
i
)
=
static_cast
<
AccDataType
>
(
alphas
[
i
]);
beta_values_
(
i
)
=
static_cast
<
AccDataType
>
(
betas
[
i
]);
};
in_dev_
=
static_cast
<
const
InDataType
*>
(
in_dev
);
...
...
@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
const
std
::
array
<
index_t
,
NumOutputDim
>
outLengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumOutputDim
>
,
NumReduction
>
outStridesArray
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
const
std
::
array
<
const
void
*
,
NumReduction
>
alphas
,
const
std
::
array
<
const
void
*
,
NumReduction
>
betas
,
const
std
::
array
<
double
,
NumReduction
>
alphas
,
const
std
::
array
<
double
,
NumReduction
>
betas
,
const
void
*
in_dev
,
const
std
::
array
<
void
*
,
NumReduction
>
out_dev_buffers
,
const
InElementwiseOperationTuple
in_elementwise_op_tuple
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -10,57 +10,25 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
template
<
typename
GridwiseReduction
,
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
AccElementwiseOperation
,
typename
GridDesc_M_K
>
__global__
void
kernel_normalization
(
const
GridDesc_M_K
x_grid_desc_m_k
,
const
GridDesc_M_K
gamma_grid_desc_m_k
,
const
GridDesc_M_K
beta_grid_desc_m_k
,
const
GridDesc_M_K
y_grid_desc_m_k
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
const
AccElementwiseOperation
acc_elementwise_op
)
{
GridwiseReduction
::
Run
(
x_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
y_grid_desc_m_k
,
num_k_block_tile_iteration
,
epsilon
,
p_x_global
,
p_gamma_global
,
p_beta_global
,
p_y_global
,
acc_elementwise_op
);
};
}
// namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Y = Normalization(X, Beta, Gamma)
// M: Invarient length
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
Acc
DataType
,
typename
Compute
DataType
,
typename
YDataType
,
typename
Acc
ElementwiseOperation
,
typename
Y
ElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
,
index_t
BlockSize
,
...
...
@@ -74,16 +42,18 @@ template <typename XDataType,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorSize
>
index_t
YDstVectorSize
,
bool
UseWelford
=
true
>
struct
DeviceNormalizationImpl
:
public
DeviceNormalization
<
XDataType
,
GammaDataType
,
BetaDataType
,
Acc
DataType
,
Compute
DataType
,
YDataType
,
Acc
ElementwiseOperation
,
Y
ElementwiseOperation
,
Rank
,
NumReduceDim
>
{
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
);
static_assert
(
((
GammaSrcVectorDim
==
0
&&
MThreadSliceSize
%
GammaSrcVectorSize
==
0
)
||
(
GammaSrcVectorDim
==
1
&&
KThreadSliceSize
%
GammaSrcVectorSize
==
0
)),
...
...
@@ -101,7 +71,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
const
std
::
vector
<
index_t
>&
inStrides
,
int
blkGroupSize
,
int
numBlockTileIteration
)
{
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
...
...
@@ -150,10 +119,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const
auto
invariantLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
int
reduceSizePerBlock
=
K_BlockTileSize
*
numBlockTileIteration
;
const
auto
inPad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
inPad_K
=
reduceSizePerBlock
*
blkGroupSize
-
reduceLength
;
const
auto
inPad_K
=
K_BlockTileSize
*
numBlockTileIteration
-
reduceLength
;
auto
in_grid_desc_m_k_padded
=
transform_tensor_descriptor
(
in_grid_desc_m_k
,
...
...
@@ -165,52 +133,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return
(
in_grid_desc_m_k_padded
);
};
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
GridwiseReduceLayernormGeneric
=
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYSrcVectorDim
,
YDstVectorSize
,
false
>
;
using
GridwiseNormalizationSweepOnce
=
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYSrcVectorDim
,
YDstVectorSize
,
true
>
;
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
));
struct
Argument
:
public
BaseArgument
{
...
...
@@ -220,51 +143,48 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
Acc
ElementwiseOperation
acc
_elementwise_op
,
AccDataTyp
e
epsilon
,
Y
ElementwiseOperation
y
_elementwise_op
,
doubl
e
epsilon
,
const
XDataType
*
p_x
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
YDataType
*
p_y
)
:
epsilon_
(
epsilon
),
p_x_
(
p_x
),
:
p_x_
(
p_x
),
p_gamma_
(
p_gamma
),
p_beta_
(
p_beta
),
p_y_
(
p_y
),
acc
_elementwise_op_
(
acc
_elementwise_op
)
y
_elementwise_op_
(
y
_elementwise_op
)
{
epsilon_
=
static_cast
<
ComputeDataType
>
(
epsilon
);
Lengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
lengths
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
xStrides
,
reduceDims
);
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
yStrides
,
reduceDims
);
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
long_index_t
invariant_
total_
length
;
long_index_t
reduce_
total_
length
;
long_index_t
invariant_length
;
long_index_t
reduce_length
;
std
::
tie
(
invariant_
total_
length
,
reduce_
total_
length
)
=
std
::
tie
(
invariant_length
,
reduce_length
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
Lengths_
);
blkGroupSize_
=
1
;
numBlockTileIteration_
=
(
reduce_total_length
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
numBlockTileIteration_
=
math
::
integer_divide_ceil
(
reduce_length
,
K_BlockTileSize
);
gridSize_
=
math
::
integer_least_multiple
(
invariant_total_length
,
M_BlockTileSize
)
/
M_BlockTileSize
*
blkGroupSize_
;
gridSize_
=
math
::
integer_divide_ceil
(
invariant_length
,
M_BlockTileSize
);
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
numBlockTileIteration_
);
gamma_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
gammaStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
MakeSrc2dDescriptor
(
Lengths_
,
gammaStrides_
,
numBlockTileIteration_
);
beta_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
betaStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
MakeSrc2dDescriptor
(
Lengths_
,
betaStrides_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
numBlockTileIteration_
);
isSweeponce_
=
x_grid_desc_m_k_
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
}
Acc
DataType
epsilon_
;
Compute
DataType
epsilon_
;
const
XDataType
*
p_x_
;
const
GammaDataType
*
p_gamma_
;
...
...
@@ -277,9 +197,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
std
::
vector
<
index_t
>
betaStrides_
;
std
::
vector
<
index_t
>
yStrides_
;
Acc
ElementwiseOperation
acc
_elementwise_op_
;
Y
ElementwiseOperation
y
_elementwise_op_
;
int
blkGroupSize_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
...
...
@@ -294,23 +213,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
kernel_main
=
arg
.
isSweeponce_
?
kernel_normalization
<
GridwiseNormalizationSweepOnce
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
>
:
kernel_normalization
<
GridwiseReduceLayernormGeneric
,
XDataType
,
auto
kernel_main
=
NormalizationKernelSelector
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
>
;
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYSrcVectorDim
,
YDstVectorSize
,
UseWelford
>
(
arg
.
isSweeponce_
);
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -328,7 +251,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
p_y_
,
arg
.
acc
_elementwise_op_
);
arg
.
y
_elementwise_op_
);
return
(
avg_time
);
};
...
...
@@ -359,6 +282,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
invariant_lowest_length
%
XSrcVectorSize
!=
0
)
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
YDstVectorSize
!=
0
)
return
false
;
};
}
else
...
...
@@ -368,12 +294,12 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
XSrcVectorSize
!=
0
)
return
false
;
};
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
YDstVectorSize
!=
0
)
{
return
false
;
}
};
// if fastest dim is not reduced
if
constexpr
(
GammaSrcVectorDim
==
0
)
...
...
@@ -421,14 +347,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
AccDataTyp
e
epsilon
,
doubl
e
epsilon
,
const
void
*
p_x
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_y
,
void
*
p_saveMean
,
void
*
p_saveInvVar
,
Acc
ElementwiseOperation
acc
_elementwise_op
)
override
Y
ElementwiseOperation
y
_elementwise_op
)
override
{
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
...
...
@@ -442,7 +368,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
betaStrides
,
yStrides
,
reduceDims
,
acc
_elementwise_op
,
y
_elementwise_op
,
epsilon
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
...
...
@@ -461,8 +387,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
// clang-format off
str
<<
"DeviceNormalizationImpl<"
<<
BlockSize
<<
","
;
str
<<
"
M_C
"
<<
MThreadClusterSize
<<
"_
S
"
<<
M
Thread
Slice
Size
<<
","
;
str
<<
"K_
C
"
<<
K
Thread
Cluster
Size
<<
"_
S
"
<<
KThreadSliceSize
<<
","
;
str
<<
"
Cluster_MK_
"
<<
MThreadClusterSize
<<
"_"
<<
K
Thread
Cluster
Size
<<
","
;
str
<<
"
Slice_M
K_"
<<
M
Thread
Slice
Size
<<
"_"
<<
KThreadSliceSize
<<
","
;
str
<<
"XYSrcVectorDim_"
<<
XYSrcVectorDim
<<
","
;
str
<<
"VectorSize_X"
<<
XSrcVectorSize
<<
"_Gamma"
<<
GammaSrcVectorSize
<<
"_Beta"
<<
BetaSrcVectorSize
<<
"_Y"
<<
YDstVectorSize
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
0 → 100644
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
template
<
typename
GridwiseWelford
,
typename
XDataType
,
typename
MeanVarDataType
,
typename
ComputeDataType
,
typename
XGridDesc_M_K
,
typename
MeanVarGridDesc_M_KBlock
>
__global__
void
kernel_normalizationSplitK1st
(
const
XGridDesc_M_K
x_grid_desc_m_k
,
const
MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock
,
index_t
num_k_block_tile_iteration
,
const
XDataType
*
const
__restrict__
p_x_global
,
MeanVarDataType
*
const
__restrict__
p_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
)
{
GridwiseWelford
::
Run
(
x_grid_desc_m_k
,
mean_var_grid_desc_m_kblock
,
num_k_block_tile_iteration
,
p_x_global
,
p_welford_mean
,
p_welford_variance
,
p_welford_count
);
};
template
<
typename
GridwiseWelfordNormalization
,
typename
MeanVarDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
MeanVarGridDesc_M_KBlock
,
typename
CountGridDesc_M_KBlock
,
typename
XYGammaBetaGridDesc_M_K
>
__global__
void
kernel_normalizationSplitK2nd
(
const
MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock
,
const
CountGridDesc_M_KBlock
count_grid_desc_m_kblock
,
const
XYGammaBetaGridDesc_M_K
x_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
gamma_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
beta_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
y_grid_desc_m_k
,
index_t
num_k_mean_var_count_iteration
,
index_t
num_k_block_tile_iteration
,
index_t
k_grid_size
,
ComputeDataType
epsilon
,
const
MeanVarDataType
*
const
p_mean_global
,
const
MeanVarDataType
*
const
p_variance_global
,
const
int32_t
*
const
p_welford_count_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
const
YElementwiseOperation
y_elementwise_op
)
{
GridwiseWelfordNormalization
::
Run
(
mean_var_grid_desc_m_kblock
,
count_grid_desc_m_kblock
,
x_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
y_grid_desc_m_k
,
num_k_mean_var_count_iteration
,
num_k_block_tile_iteration
,
k_grid_size
,
epsilon
,
p_mean_global
,
p_variance_global
,
p_welford_count_global
,
p_x_global
,
p_gamma_global
,
p_beta_global
,
p_y_global
,
y_elementwise_op
);
};
}
// namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Y = Normalization(X, Beta, Gamma)
// M: Invarient length
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
YDataType
,
typename
YElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XYVectorDim
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorSize
>
struct
DeviceNormalizationSplitKImpl
:
public
DeviceNormalization
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
YElementwiseOperation
,
Rank
,
NumReduceDim
>
{
using
MeanVarDataType
=
ComputeDataType
;
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
);
static_assert
(
((
GammaSrcVectorDim
==
0
&&
MThreadSliceSize
%
GammaSrcVectorSize
==
0
)
||
(
GammaSrcVectorDim
==
1
&&
KThreadSliceSize
%
GammaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"
);
static_assert
(
((
BetaSrcVectorDim
==
0
&&
MThreadSliceSize
%
BetaSrcVectorSize
==
0
)
||
(
BetaSrcVectorDim
==
1
&&
KThreadSliceSize
%
BetaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"
);
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
const
std
::
vector
<
index_t
>&
inStrides
,
int
kBlockSize
,
int
numBlockTileIteration
)
{
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
numSrcDim
=
Rank
;
static
constexpr
bool
reduceAllDim
=
(
NumInvariantDim
==
0
);
const
auto
tupleSrcLengths
=
make_tuple_from_array
(
inLengths
,
Number
<
numSrcDim
>
{});
const
auto
tupleSrcStrides
=
make_tuple_from_array
(
inStrides
,
Number
<
numSrcDim
>
{});
const
auto
inDesc
=
make_naive_tensor_descriptor
(
tupleSrcLengths
,
tupleSrcStrides
);
const
auto
in_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
reduceAllDim
)
{
const
auto
one_dim_inDesc
=
transform_tensor_descriptor
(
inDesc
,
make_tuple
(
make_merge_transform
(
tupleSrcLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
numSrcDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
transform_tensor_descriptor
(
one_dim_inDesc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
1
,
one_dim_inDesc
.
GetLength
(
Number
<
0
>
{})))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{}));
}
else
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
const
auto
reduceDimLengths
=
make_tuple_from_array_and_index_seq
(
inLengths
,
ReduceDims
{});
const
auto
invariantDimLengths
=
make_tuple_from_array_and_index_seq
(
inLengths
,
InvariantDims
{});
return
transform_tensor_descriptor
(
inDesc
,
make_tuple
(
make_merge_transform
(
invariantDimLengths
),
make_merge_transform
(
reduceDimLengths
)),
make_tuple
(
InvariantDims
{},
ReduceDims
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}();
const
auto
invariantLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
int
reduceSizePerBlock
=
K_BlockTileSize
*
numBlockTileIteration
;
const
auto
inPad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
inPad_K
=
reduceSizePerBlock
*
kBlockSize
-
reduceLength
;
auto
in_grid_desc_m_k_padded
=
transform_tensor_descriptor
(
in_grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
inPad_M
),
make_right_pad_transform
(
reduceLength
,
inPad_K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
in_grid_desc_m_k_padded
);
};
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
KPerTile
>
static
auto
MakeMeanVarDescriptor_M_K
(
index_t
M
,
index_t
K
)
{
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
K
,
I1
));
return
PadTensorDescriptor
(
grid_desc_m_k
,
make_tuple
(
MPerTile
,
KPerTile
),
DoPads
{});
}
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
KPerTile
>
static
auto
MakeCountDescriptor_M_K
(
index_t
M
,
index_t
K
)
{
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I0
,
I1
));
return
PadTensorDescriptor
(
grid_desc_m_k
,
make_tuple
(
MPerTile
,
KPerTile
),
DoPads
{});
}
using
SrcGridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
Kernel1MeanVarGridDesc_M_KBlock
=
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
1
,
1
>
(
1
,
1
));
using
Kernel2MeanVarGridDesc_M_KBlock
=
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
using
Kernel2CountGridDesc_M_KBlock
=
decltype
(
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
using
GridwiseWelford
=
GridwiseNormalizationSplitK1st
<
XDataType
,
ComputeDataType
,
MeanVarDataType
,
SrcGridDesc_M_K
,
Kernel1MeanVarGridDesc_M_KBlock
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYVectorDim
,
XSrcVectorSize
>
;
using
GridwiseWelfordNormalization
=
GridwiseNormalizationSplitK2nd
<
MeanVarDataType
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYVectorDim
,
YDstVectorSize
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
vector
<
index_t
>
lengths
,
const
std
::
vector
<
index_t
>
xStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
YElementwiseOperation
y_elementwise_op
,
double
epsilon
,
const
XDataType
*
p_x
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
YDataType
*
p_y
)
:
p_x_
(
p_x
),
p_gamma_
(
p_gamma
),
p_beta_
(
p_beta
),
p_y_
(
p_y
),
p_workspace_mean_
{
nullptr
},
p_workspace_var_
{
nullptr
},
p_workspace_count_
{
nullptr
},
y_elementwise_op_
(
y_elementwise_op
)
{
epsilon_
=
static_cast
<
ComputeDataType
>
(
epsilon
);
Lengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
lengths
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
xStrides
,
reduceDims
);
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
yStrides
,
reduceDims
);
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
std
::
tie
(
MRaw_
,
KRaw_
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
Lengths_
);
numBlockTileIteration_
=
1
;
while
(
true
)
{
int
testKGridSize
=
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
*
numBlockTileIteration_
);
// we want the kGridSize_ be not more than 128
if
(
testKGridSize
<=
128
)
break
;
++
numBlockTileIteration_
;
};
kGridSize_
=
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
*
numBlockTileIteration_
);
gridSize_
=
math
::
integer_divide_ceil
(
MRaw_
,
M_BlockTileSize
)
*
kGridSize_
;
// We do not use vector load for mean, var and count
static
constexpr
index_t
K_MeanVarCountBlockTileSize
=
KThreadClusterSize
;
numMeanVarCountIteration_
=
math
::
integer_divide_ceil
(
kGridSize_
,
K_MeanVarCountBlockTileSize
);
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
kGridSize_
,
numBlockTileIteration_
);
gamma_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
gammaStrides_
,
kGridSize_
,
numBlockTileIteration_
);
beta_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
betaStrides_
,
kGridSize_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
kGridSize_
,
numBlockTileIteration_
);
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
kernel1_mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
M_BlockTileSize
,
1
>
(
MRaw_
,
kGridSize_
);
kernel2_mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
kernel2_count_grid_desc_m_kblock_
=
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
K_MeanVarCountBlockTileSize
>
(
MRaw_
,
kGridSize_
);
}
ComputeDataType
epsilon_
;
const
XDataType
*
p_x_
;
const
GammaDataType
*
p_gamma_
;
const
BetaDataType
*
p_beta_
;
YDataType
*
p_y_
;
void
*
p_workspace_mean_
;
void
*
p_workspace_var_
;
void
*
p_workspace_count_
;
std
::
vector
<
index_t
>
Lengths_
;
std
::
vector
<
index_t
>
xStrides_
;
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
betaStrides_
;
std
::
vector
<
index_t
>
yStrides_
;
YElementwiseOperation
y_elementwise_op_
;
int
kGridSize_
;
int
numMeanVarCountIteration_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
SrcGridDesc_M_K
x_grid_desc_m_k_
;
SrcGridDesc_M_K
gamma_grid_desc_m_k_
;
SrcGridDesc_M_K
beta_grid_desc_m_k_
;
SrcGridDesc_M_K
y_grid_desc_m_k_
;
Kernel1MeanVarGridDesc_M_KBlock
kernel1_mean_var_grid_desc_m_kblock_
;
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
Kernel2CountGridDesc_M_KBlock
kernel2_count_grid_desc_m_kblock_
;
index_t
MRaw_
;
// invarient length
index_t
KRaw_
;
// reduce length
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
arg
.
p_workspace_mean_
==
nullptr
||
arg
.
p_workspace_var_
==
nullptr
||
arg
.
p_workspace_count_
==
nullptr
)
throw
std
::
runtime_error
(
"wrong! WorkSpace pointer has not been set"
);
auto
kernel1
=
kernel_normalizationSplitK1st
<
GridwiseWelford
,
XDataType
,
MeanVarDataType
,
ComputeDataType
,
SrcGridDesc_M_K
,
Kernel1MeanVarGridDesc_M_KBlock
>
;
auto
kernel2
=
kernel_normalizationSplitK2nd
<
GridwiseWelfordNormalization
,
MeanVarDataType
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
>
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel1
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
kernel1_mean_var_grid_desc_m_kblock_
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel2
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
kernel2_mean_var_grid_desc_m_kblock_
,
arg
.
kernel2_count_grid_desc_m_kblock_
,
arg
.
x_grid_desc_m_k_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
numMeanVarCountIteration_
,
arg
.
numBlockTileIteration_
,
arg
.
kGridSize_
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
p_x_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
p_y_
,
arg
.
y_elementwise_op_
);
return
avg_time
;
};
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
};
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
size_t
workspace_size
=
0
;
int
welford_size
=
pArg_
->
MRaw_
*
pArg_
->
kGridSize_
;
// workspace for welford intermediate mean
workspace_size
+=
welford_size
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford intermediate variance
workspace_size
+=
welford_size
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
kGridSize_
*
sizeof
(
int32_t
)
+
64
;
return
(
workspace_size
);
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
pArg_
->
p_workspace_
=
p_workspace
;
int
welford_size
=
pArg_
->
MRaw_
*
pArg_
->
kGridSize_
;
// setup buffer used for intermediate welford mean
pArg_
->
p_workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
index_t
mean_space_sz
=
welford_size
*
sizeof
(
MeanVarDataType
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
// setup buffer used for intermediate welford varirance
pArg_
->
p_workspace_var_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_mean_
)
+
mean_space_sz
;
index_t
variance_space_sz
=
welford_size
*
sizeof
(
MeanVarDataType
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
// setup buffer used for intermediate welford count
pArg_
->
p_workspace_count_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_var_
)
+
variance_space_sz
;
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
if
constexpr
(
XYVectorDim
==
0
)
{
if
constexpr
(
NumInvariantDim
==
0
)
{
return
false
;
}
else
{
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
XSrcVectorSize
!=
0
)
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
YDstVectorSize
!=
0
)
return
false
;
};
}
else
{
if
(
p_arg_
->
xStrides_
[
Rank
-
1
]
!=
1
)
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
XSrcVectorSize
!=
0
)
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
YDstVectorSize
!=
0
)
return
false
;
};
// if fastest dim is not reduced
if
constexpr
(
GammaSrcVectorDim
==
0
)
{
if
(
p_arg_
->
gammaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
GammaSrcVectorSize
!=
0
)
return
false
;
}
else
// if fastest dim is reduced
{
if
(
p_arg_
->
gammaStrides_
[
Rank
-
1
]
!=
1
)
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
GammaSrcVectorSize
!=
0
)
return
false
;
}
// if fastest dim is not reduced
if
constexpr
(
BetaSrcVectorDim
==
0
)
{
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
BetaSrcVectorSize
!=
0
)
return
false
;
}
else
// if fastest dim is reduced
{
if
(
p_arg_
->
betaStrides_
[
Rank
-
1
]
!=
1
)
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
BetaSrcVectorSize
!=
0
)
return
false
;
}
if
(
p_arg_
->
kGridSize_
<=
1
)
return
false
;
return
true
;
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
lengths
,
const
std
::
vector
<
index_t
>
xStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
double
epsilon
,
const
void
*
p_x
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_y
,
void
*
p_saveMean
,
void
*
p_saveInvVar
,
YElementwiseOperation
y_elementwise_op
)
override
{
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
// forward pass could speedup in the backward
ignore
=
p_saveMean
;
ignore
=
p_saveInvVar
;
return
std
::
make_unique
<
Argument
>
(
lengths
,
xStrides
,
gammaStrides
,
betaStrides
,
yStrides
,
reduceDims
,
y_elementwise_op
,
epsilon
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
static_cast
<
const
BetaDataType
*>
(
p_beta
),
static_cast
<
YDataType
*>
(
p_y
));
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceNormalizationSplitKImpl<"
<<
BlockSize
<<
","
;
str
<<
"Cluster_MK_"
<<
MThreadClusterSize
<<
"_"
<<
KThreadClusterSize
<<
","
;
str
<<
"Slice_MK_"
<<
MThreadSliceSize
<<
"_"
<<
KThreadSliceSize
<<
","
;
str
<<
"XYSrcVectorDim_"
<<
XYVectorDim
<<
","
;
str
<<
"VectorSize_X"
<<
XSrcVectorSize
<<
"_Gamma"
<<
GammaSrcVectorSize
<<
"_Beta"
<<
BetaSrcVectorSize
<<
"_Y"
<<
YDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool
2d
_fwd.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -20,16 +20,18 @@ namespace device {
template
<
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
IndexDataType
,
// enable if OutputIndex == true
typename
ComputeDataType
,
ck
::
ReduceTensorOp
ReduceOpId
,
bool
OuputIndex
,
bool
Ou
t
putIndex
,
ck
::
index_t
BlockSize
,
ck
::
index_t
ReduceMThreadClusterSize
,
ck
::
index_t
ReduceKThreadClusterSize
,
ck
::
index_t
ReduceMThreadSliceSize
,
ck
::
index_t
ReduceKThreadSliceSize
,
ck
::
index_t
InSrcOutDstVectorSize
>
struct
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
:
public
DevicePool2dFwd
<
ReduceOpId
>
struct
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
:
public
DevicePoolFwd
<
4
,
2
,
InDataType
,
OutDataType
,
IndexDataType
,
ReduceOpId
,
OutputIndex
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -38,7 +40,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
using
IndexDataType
=
int32_t
;
static
constexpr
index_t
InOutRank
=
4
;
static
constexpr
index_t
WindowRank
=
2
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
...
...
@@ -59,12 +62,12 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
static
auto
MakeABGridDescriptor_A_M_K_B_M
(
ck
::
index_t
N
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
2
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
window_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
window_strides
,
std
::
array
<
ck
::
index_t
,
2
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
2
>
input_right_pads
)
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
1
];
...
...
@@ -141,9 +144,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
}
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
(
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}));
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
(
1
,
1
,
{},
{},
{},
{},
{},
{}));
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
...
...
@@ -152,15 +153,15 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
{
Argument
(
const
InDataType
*
p_in_dev
,
OutDataType
*
p_out_dev
,
int
*
p_out_indices_dev
,
IndexDataType
*
p_out_indices_dev
,
ck
::
index_t
N
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
2
>&
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>&
window_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>&
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>&
window_strides
,
std
::
array
<
ck
::
index_t
,
2
>&
input_left_pads
,
std
::
array
<
ck
::
index_t
,
2
>&
input_right_pads
)
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>&
window_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>&
window_strides
,
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
:
p_in_dev_
{
p_in_dev
},
p_out_dev_
{
p_out_dev
},
p_out_indices_dev_
{
p_out_indices_dev
},
...
...
@@ -190,7 +191,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
const
InDataType
*
p_in_dev_
;
OutDataType
*
p_out_dev_
;
int
*
p_out_indices_dev_
;
IndexDataType
*
p_out_indices_dev_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_M
b_grid_desc_m_
;
InElementwiseOperation
in_element_op_
;
...
...
@@ -208,7 +209,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
using
gridwise_reduce
=
GridwiseReduction_mk_to_m_threadwise
<
InDataType
,
OutDataType
,
Acc
DataType
,
Compute
DataType
,
IndexDataType
,
AGridDesc_M_K
,
BGridDesc_M
,
...
...
@@ -224,12 +225,14 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
InSrcOutDstVectorSize
,
InSrcOutDstVectorSize
>
;
const
auto
kernel
=
kernel_reduce_threadwise
<
gridwise_reduce
,
OuputIndex
,
const
auto
kernel
=
kernel_reduce_threadwise
<
gridwise_reduce
,
OutputIndex
,
true
,
// pooling need to return global index
false
,
// don't have index input
InDataType
,
OutDataType
,
Acc
DataType
,
Compute
DataType
,
IndexDataType
,
AGridDesc_M_K
,
BGridDesc_M
,
...
...
@@ -280,22 +283,42 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
ck
::
index_t
N
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
2
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
window_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
window_strides
,
std
::
array
<
ck
::
index_t
,
2
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
2
>
input_right_pads
)
override
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NHWC
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
pooling_dims
)
override
{
if
(
input_lengths
.
size
()
!=
InOutRank
||
window_lengths
.
size
()
!=
WindowRank
||
input_lengths
.
size
()
!=
InOutRank
||
window_strides
.
size
()
!=
WindowRank
||
input_left_pads
.
size
()
!=
WindowRank
||
input_right_pads
.
size
()
!=
WindowRank
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
if
(
pooling_dims
!=
std
::
vector
<
ck
::
index_t
>
{
2
,
3
})
throw
std
::
runtime_error
(
"pooling_dims only support {2, 3} in pool2d so far"
);
index_t
N
=
input_lengths
[
0
];
index_t
C
=
input_lengths
[
1
];
index_t
Hi
=
input_lengths
[
2
];
index_t
Wi
=
input_lengths
[
3
];
index_t
Ho
=
output_lengths
[
2
];
index_t
Wo
=
output_lengths
[
3
];
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
=
{
Hi
,
Wi
};
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
=
{
Ho
,
Wo
};
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_dev
),
static_cast
<
OutDataType
*>
(
p_out_dev
),
static_cast
<
int
*>
(
p_out_indices_dev
),
static_cast
<
IndexDataType
*>
(
p_out_indices_dev
),
N
,
C
,
input_spatial_lengths
,
window_
spatial_
lengths
,
window_lengths
,
output_spatial_lengths
,
window_strides
,
input_left_pads
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp
0 → 100644
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InDataType
,
typename
OutDataType
,
typename
IndexDataType
,
// enable if OutputIndex == true
typename
ComputeDataType
,
ck
::
ReduceTensorOp
ReduceOpId
,
bool
OutputIndex
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MThreadClusterSize
,
ck
::
index_t
KThreadClusterSize
,
ck
::
index_t
MThreadSliceSize
,
ck
::
index_t
KThreadSliceSize
,
ck
::
index_t
InSrcOutDstVectorSize
>
struct
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
:
public
DevicePoolFwd
<
5
,
3
,
InDataType
,
OutDataType
,
IndexDataType
,
ReduceOpId
,
OutputIndex
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
index_t
InOutRank
=
5
;
static
constexpr
index_t
WindowRank
=
3
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not
// reduced.
static
constexpr
index_t
InSrcOutDstVectorDim
=
0
;
static
constexpr
ck
::
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
ck
::
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
MakeABGridDescriptor_A_M_K_B_M
(
ck
::
index_t
N
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
index_t
Wo
=
output_spatial_lengths
[
2
];
const
index_t
Z
=
window_spatial_lengths
[
0
];
const
index_t
Y
=
window_spatial_lengths
[
1
];
const
index_t
X
=
window_spatial_lengths
[
2
];
const
index_t
ConvStrideD
=
window_strides
[
0
];
const
index_t
ConvStrideH
=
window_strides
[
1
];
const
index_t
ConvStrideW
=
window_strides
[
2
];
const
index_t
InLeftPadD
=
input_left_pads
[
0
];
const
index_t
InLeftPadH
=
input_left_pads
[
1
];
const
index_t
InLeftPadW
=
input_left_pads
[
2
];
const
index_t
InRightPadD
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
2
];
const
index_t
MRaw
=
N
*
Do
*
Ho
*
Wo
*
C
;
const
index_t
MPad
=
math
::
integer_least_multiple
(
MRaw
,
M_BlockTileSize
)
-
MRaw
;
const
index_t
KRaw
=
Z
*
Y
*
X
;
const
index_t
KPad
=
math
::
integer_least_multiple
(
KRaw
,
K_BlockTileSize
)
-
KRaw
;
// A[ReduceM, ReduceK]
const
auto
in_grid_desc_n_di_hi_wi_c
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
const
auto
in_grid_desc_n_dip_hip_wip_c
=
transform_tensor_descriptor
(
in_grid_desc_n_di_hi_wi_c
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_grid_desc_n_z_do_y_ho_x_wo_c
=
transform_tensor_descriptor
(
in_grid_desc_n_dip_hip_wip_c
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
I1
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
I1
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
I1
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
in_grid_desc_reducemraw_reducekraw
=
transform_tensor_descriptor
(
in_grid_desc_n_z_do_y_ho_x_wo_c
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
,
C
)),
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
))),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
7
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_grid_desc_reducem_reducek
=
transform_tensor_descriptor
(
in_grid_desc_reducemraw_reducekraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B[ReduceM]
const
auto
out_grid_desc_reducemraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
*
C
));
const
auto
out_grid_desc_reducem
=
transform_tensor_descriptor
(
out_grid_desc_reducemraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
}
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
(
1
,
1
,
{},
{},
{},
{},
{},
{}));
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_in_dev
,
OutDataType
*
p_out_dev
,
IndexDataType
*
p_out_indices_dev
,
ck
::
index_t
N
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>&
window_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>&
window_strides
,
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
:
p_in_dev_
{
p_in_dev
},
p_out_dev_
{
p_out_dev
},
p_out_indices_dev_
{
p_out_indices_dev
},
a_grid_desc_m_k_
{},
b_grid_desc_m_
{}
{
const
auto
descs
=
MakeABGridDescriptor_A_M_K_B_M
(
N
,
C
,
input_spatial_lengths
,
window_spatial_lengths
,
output_spatial_lengths
,
window_strides
,
input_left_pads
,
input_right_pads
);
a_grid_desc_m_k_
=
descs
[
I0
];
b_grid_desc_m_
=
descs
[
I1
];
invariant_lowest_length_
=
C
;
int32_t
reduceLength
=
window_spatial_lengths
[
0
]
*
window_spatial_lengths
[
1
]
*
window_spatial_lengths
[
2
];
std
::
tie
(
in_element_op_
,
acc_element_op_
)
=
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
reduceLength
);
}
const
InDataType
*
p_in_dev_
;
OutDataType
*
p_out_dev_
;
IndexDataType
*
p_out_indices_dev_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_M
b_grid_desc_m_
;
InElementwiseOperation
in_element_op_
;
AccElementwiseOperation
acc_element_op_
;
// for checking vector load/store
ck
::
index_t
invariant_lowest_length_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
using
gridwise_reduce
=
GridwiseReduction_mk_to_m_threadwise
<
InDataType
,
OutDataType
,
ComputeDataType
,
IndexDataType
,
AGridDesc_M_K
,
BGridDesc_M
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
false
,
// propagate_nan
BlockSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcOutDstVectorDim
,
InSrcOutDstVectorSize
,
InSrcOutDstVectorSize
>
;
const
auto
kernel
=
kernel_reduce_threadwise
<
gridwise_reduce
,
OutputIndex
,
true
,
// pooling need to return global index
false
,
// don't have index input
InDataType
,
OutDataType
,
ComputeDataType
,
IndexDataType
,
AGridDesc_M_K
,
BGridDesc_M
,
InElementwiseOperation
,
AccElementwiseOperation
>
;
ck
::
index_t
M
=
arg
.
a_grid_desc_m_k_
.
GetLength
(
I0
);
const
index_t
grid_size
=
(
M
/
M_BlockTileSize
);
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_m_
,
arg
.
in_element_op_
,
arg
.
acc_element_op_
,
float
(
1
),
arg
.
p_in_dev_
,
nullptr
,
float
(
0
),
arg
.
p_out_dev_
,
arg
.
p_out_indices_dev_
);
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
pArg
->
invariant_lowest_length_
%
InSrcOutDstVectorSize
!=
0
)
{
return
false
;
}
return
true
;
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
pooling_dims
)
override
{
if
(
input_lengths
.
size
()
!=
InOutRank
||
window_lengths
.
size
()
!=
WindowRank
||
input_lengths
.
size
()
!=
InOutRank
||
window_strides
.
size
()
!=
WindowRank
||
input_left_pads
.
size
()
!=
WindowRank
||
input_right_pads
.
size
()
!=
WindowRank
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
if
(
pooling_dims
!=
std
::
vector
<
ck
::
index_t
>
{
2
,
3
,
4
})
throw
std
::
runtime_error
(
"pooling_dims only support {2, 3, 4} in pool3d so far"
);
index_t
N
=
input_lengths
[
0
];
index_t
C
=
input_lengths
[
1
];
index_t
Di
=
input_lengths
[
2
];
index_t
Hi
=
input_lengths
[
3
];
index_t
Wi
=
input_lengths
[
4
];
index_t
Do
=
output_lengths
[
2
];
index_t
Ho
=
output_lengths
[
3
];
index_t
Wo
=
output_lengths
[
4
];
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
=
{
Di
,
Hi
,
Wi
};
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
=
{
Do
,
Ho
,
Wo
};
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_dev
),
static_cast
<
OutDataType
*>
(
p_out_dev
),
static_cast
<
IndexDataType
*>
(
p_out_indices_dev
),
N
,
C
,
input_spatial_lengths
,
window_lengths
,
output_spatial_lengths
,
window_strides
,
input_left_pads
,
input_right_pads
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcOutDstVectorSize_"
<<
InSrcOutDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp
0 → 100644
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_put_element.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// output[indices] = input
template
<
typename
InDataType
,
typename
IndexDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
InMemoryDataOperationEnum
MemOp
,
ck
::
index_t
InVectorSize
>
struct
DevicePutElementImpl
:
public
DevicePutElement
<
InDataType
,
IndexDataType
,
OutDataType
,
ElementwiseOperation
,
MemOp
>
{
template
<
typename
Desc_M
>
static
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
,
index_t
gridSize
,
index_t
blockSize
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
const
auto
m
=
desc_m
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
InVectorSize
;
const
auto
pad
=
math
::
integer_least_multiple
(
m
,
loop_step
)
-
m
;
const
auto
desc_m_pad
=
transform_tensor_descriptor
(
desc_m
,
make_tuple
(
make_right_pad_transform
(
m
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m_pad
;
}
static
auto
MakeDescriptor_M
(
index_t
length
,
index_t
gridSize
,
index_t
blockSize
)
{
const
auto
desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
length
));
return
PadDescriptor_M_1d
(
desc_m
,
gridSize
,
blockSize
);
}
using
InGrid1dDesc
=
decltype
(
MakeDescriptor_M
(
1
,
1
,
1
));
using
GridwisePutElement
=
GridwisePutElement_1D
<
InGrid1dDesc
,
InDataType
,
IndexDataType
,
OutDataType
,
ElementwiseOperation
,
MemOp
,
InVectorSize
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_input
,
const
IndexDataType
*
p_indices
,
OutDataType
*
p_output
,
index_t
input_length
,
ElementwiseOperation
elementwise_op
)
:
p_input_
{
p_input
},
p_indices_
{
p_indices
},
p_output_
{
p_output
},
input_length_raw_
{
input_length
},
elementwise_op_
{
elementwise_op
},
blockSize_
{
256
}
{
}
const
InDataType
*
p_input_
;
const
IndexDataType
*
p_indices_
;
OutDataType
*
p_output_
;
index_t
input_length_raw_
;
ElementwiseOperation
elementwise_op_
;
index_t
blockSize_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
index_t
gridSize
=
getAvailableComputeUnitCount
(
stream_config
);
InGrid1dDesc
in_grid_desc
=
MakeDescriptor_M
(
arg
.
input_length_raw_
,
gridSize
,
arg
.
blockSize_
);
const
auto
kernel
=
kernel_put_element_1d
<
GridwisePutElement
,
InGrid1dDesc
,
InDataType
,
IndexDataType
,
OutDataType
,
ElementwiseOperation
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
0
,
in_grid_desc
,
arg
.
p_input_
,
arg
.
p_indices_
,
arg
.
p_output_
,
arg
.
elementwise_op_
);
return
elapsed_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
pArg
->
input_length_raw_
%
InVectorSize
!=
0
)
{
return
false
;
}
return
true
;
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_input
,
const
void
*
p_indices
,
void
*
p_output
,
index_t
input_length
,
index_t
,
ElementwiseOperation
elementwise_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_input
),
static_cast
<
const
IndexDataType
*>
(
p_indices
),
static_cast
<
OutDataType
*>
(
p_output
),
input_length
,
elementwise_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -40,8 +40,16 @@ template <typename InDataType,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
struct
DeviceReduceMultiBlock
:
public
DeviceReduce
<
Rank
,
NumReduceDim
,
InElementwiseOperation
,
AccElementwiseOperation
>
struct
DeviceReduceMultiBlock
:
public
DeviceReduce
<
InDataType
,
AccDataType
,
OutDataType
,
Rank
,
NumReduceDim
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
PropagateNan
,
OutputIndex
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
...
...
@@ -67,7 +75,7 @@ struct DeviceReduceMultiBlock
static
constexpr
bool
use_multiblock
=
(
OutMemoryDataOperation
==
InMemoryDataOperationEnum
::
AtomicAdd
);
static_assert
(
ck
::
reduce
::
InMemoryDataOperatonSupportedOnDataType
<
OutMemoryDataOperation
,
static_assert
(
ck
::
reduce
::
InMemoryDataOperat
i
onSupportedOnDataType
<
OutMemoryDataOperation
,
OutDataType
>::
value
,
"The OutDataType must support the specified OutMemoryDataOperation!"
);
...
...
@@ -209,8 +217,8 @@ struct DeviceReduceMultiBlock
const
std
::
array
<
index_t
,
NumDstDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
float
alpha
,
float
beta
,
double
alpha
,
double
beta
,
const
InDataType
*
in_dev
,
const
IndexDataType
*
in_index_dev
,
OutDataType
*
out_dev
,
...
...
@@ -494,8 +502,8 @@ struct DeviceReduceMultiBlock
const
std
::
array
<
index_t
,
NumDstDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
float
alpha
,
float
beta
,
double
alpha
,
double
beta
,
const
void
*
in_dev
,
const
void
*
in_index_dev
,
void
*
out_dev
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -28,6 +28,7 @@ template <typename InDataType,
typename
AccElementwiseOperation
,
bool
PropagateNan
,
bool
OutputIndex
,
bool
TransformIndexKtoGlobal
,
bool
HaveIndexInputIfOutputIndex
,
index_t
BlockSize
,
index_t
MThreadSliceSize
,
...
...
@@ -35,8 +36,17 @@ template <typename InDataType,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
struct
DeviceReduceThreadWise
:
public
DeviceReduce
<
Rank
,
NumReduceDim
,
InElementwiseOperation
,
AccElementwiseOperation
>
struct
DeviceReduceThreadWise
:
public
DeviceReduce
<
InDataType
,
AccDataType
,
OutDataType
,
Rank
,
NumReduceDim
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
PropagateNan
,
OutputIndex
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
...
...
@@ -156,8 +166,8 @@ struct DeviceReduceThreadWise
const
std
::
array
<
index_t
,
NumDstDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
float
alpha
,
float
beta
,
double
alpha
,
double
beta
,
const
InDataType
*
in_dev
,
OutDataType
*
out_dev
,
IndexDataType
*
out_index_dev
,
...
...
@@ -251,6 +261,7 @@ struct DeviceReduceThreadWise
const
auto
kernel
=
kernel_reduce_threadwise
<
GridwiseReduce
,
OutputIndex
,
TransformIndexKtoGlobal
,
HaveIndexInput
,
InDataType
,
OutDataType
,
...
...
@@ -332,8 +343,8 @@ struct DeviceReduceThreadWise
const
std
::
array
<
index_t
,
NumDstDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDstDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
float
alpha
,
float
beta
,
double
alpha
,
double
beta
,
const
void
*
in_dev
,
const
void
*
in_index_dev
,
void
*
out_dev
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -38,16 +38,9 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
OutDataType
,
InElementwiseOp
,
AccElementwiseOp
,
Rank
>
Rank
,
NumReduceDim
>
{
static
constexpr
index_t
kRank
=
Rank
;
static
constexpr
index_t
kNumReduceDim
=
NumReduceDim
;
static
constexpr
index_t
kNumInvariantDim
=
Rank
-
NumReduceDim
;
virtual
index_t
GetRank
()
const
override
{
return
kRank
;
}
virtual
index_t
GetNumReduceDim
()
const
override
{
return
kNumReduceDim
;
}
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
NumSrcDim
=
Rank
;
...
...
@@ -156,19 +149,20 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
Argument
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
AccDataTyp
e
alpha
,
AccDataTyp
e
beta
,
doubl
e
alpha
,
doubl
e
beta
,
const
InDataType
*
in_dev
,
OutDataType
*
out_dev
,
InElementwiseOp
in_elementwise_op
,
AccElementwiseOp
acc_elementwise_op
)
:
alpha_
{
alpha
},
beta_
{
beta
},
in_dev_
{
in_dev
},
:
in_dev_
{
in_dev
},
out_dev_
{
out_dev
},
in_elementwise_op_
{
in_elementwise_op
},
acc_elementwise_op_
{
acc_elementwise_op
}
{
alpha_
=
static_cast
<
AccDataType
>
(
alpha
);
beta_
=
static_cast
<
AccDataType
>
(
beta
);
if
(
Rank
!=
inLengths
.
size
()
||
Rank
!=
inStrides
.
size
()
||
NumReduceDim
!=
reduceDims
.
size
())
{
...
...
@@ -286,13 +280,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
{
if
constexpr
(
InSrcVectorDim
==
0
)
{
if
constexpr
(
k
NumInvariantDim
==
0
)
if
constexpr
(
NumInvariantDim
==
0
)
{
return
false
;
}
else
{
if
(
arg
.
inStrides_
[
k
NumInvariantDim
-
1
]
!=
1
&&
InSrcVectorSize
!=
1
)
if
(
arg
.
inStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
InSrcVectorSize
!=
1
)
{
return
false
;
}
...
...
@@ -315,7 +309,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
}
// To improve
if
(
k
NumInvariantDim
>
0
&&
arg
.
invariant_lowest_length_
%
OutDstVectorSize
!=
0
)
if
(
NumInvariantDim
>
0
&&
arg
.
invariant_lowest_length_
%
OutDstVectorSize
!=
0
)
{
return
false
;
}
...
...
@@ -336,8 +330,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
static
auto
MakeArgument
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
int
>
reduceDims
,
const
AccDataTyp
e
alpha
,
const
AccDataTyp
e
beta
,
doubl
e
alpha
,
doubl
e
beta
,
const
InDataType
*
in_dev
,
OutDataType
*
out_dev
,
InElementwiseOp
in_elementwise_op
,
...
...
@@ -375,8 +369,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
int
>
reduceDims
,
const
void
*
alpha
,
const
void
*
beta
,
double
alpha
,
double
beta
,
const
void
*
in_dev
,
void
*
out_dev
,
InElementwiseOp
in_elementwise_op
,
...
...
@@ -385,8 +379,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
return
std
::
make_unique
<
Argument
>
(
inLengths
,
inStrides
,
reduceDims
,
*
static_cast
<
const
AccDataType
*>
(
alpha
)
,
*
static_cast
<
const
AccDataType
*>
(
beta
)
,
alpha
,
beta
,
static_cast
<
const
InDataType
*>
(
in_dev
),
static_cast
<
OutDataType
*>
(
out_dev
),
in_elementwise_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding
3
_forward_layernorm.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding
s
_forward_layernorm.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -12,7 +12,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding
3
_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding
s
_forward_layernorm.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -24,16 +24,17 @@ template <typename EmbType,
typename
BetaDataType
,
typename
AccDataType
,
typename
OutType
,
typename
EmbElementwiseOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
DimClusterSize
,
ck
::
index_t
RowClusterSize
,
ck
::
index_t
DimPerBlock
,
ck
::
index_t
RowPerBlock
,
ck
::
index_t
DimThreadSize
,
ck
::
index_t
RowVectorSize
>
struct
DeviceSparseEmbedding3ForwardLayernorm
:
public
BaseOperator
ck
::
index_t
RowVectorSize
,
ck
::
index_t
NumEmbeddings
>
struct
DeviceSparseEmbeddingsForwardLayernorm
:
public
BaseOperator
{
static
auto
MakeOutputDescriptor
(
const
index_t
index_length
,
const
index_t
rows
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
index_length
,
rows
));
...
...
@@ -42,96 +43,79 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
struct
Argument
:
public
BaseArgument
{
Argument
(
OutType
*
p_out
,
const
EmbType
*
p_emb_a
,
const
EmbType
*
p_emb_b
,
const
EmbType
*
p_emb_c
,
const
IndexType
*
p_index_a
,
const
IndexType
*
p_index_b
,
const
IndexType
*
p_index_c
,
const
ck
::
Array
<
EmbType
*
,
NumEmbeddings
>&
p_embs
,
const
ck
::
Array
<
IndexType
*
,
NumEmbeddings
>&
p_indexs
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
const
ck
::
index_t
NumRows
,
const
ck
::
index_t
EmbeddingDim
,
const
ck
::
index_t
IndexLength
,
const
AccDataType
epsilon
)
const
AccDataType
epsilon
,
const
EmbElementwiseOperation
emb_elementwise_op
)
:
p_out_
(
p_out
),
p_emb_a_
(
p_emb_a
),
p_emb_b_
(
p_emb_b
),
p_emb_c_
(
p_emb_c
),
p_index_a_
(
p_index_a
),
p_index_b_
(
p_index_b
),
p_index_c_
(
p_index_c
),
p_embs_
(
p_embs
),
p_indexs_
(
p_indexs
),
p_gamma_
(
p_gamma
),
p_beta_
(
p_beta
),
NumRows_
(
NumRows
),
EmbeddingDim_
(
EmbeddingDim
),
IndexLength_
(
IndexLength
),
epsilon_
(
epsilon
)
epsilon_
(
epsilon
),
emb_elementwise_op_
(
emb_elementwise_op
)
{
grid_size_
=
(
IndexLength
+
DimClusterSize
-
1
)
/
DimClusterSize
;
}
OutType
*
p_out_
;
const
EmbType
*
p_emb_a_
;
const
EmbType
*
p_emb_b_
;
const
EmbType
*
p_emb_c_
;
const
IndexType
*
p_index_a_
;
const
IndexType
*
p_index_b_
;
const
IndexType
*
p_index_c_
;
ck
::
Array
<
EmbType
*
,
NumEmbeddings
>
p_embs_
;
ck
::
Array
<
IndexType
*
,
NumEmbeddings
>
p_indexs_
;
const
GammaDataType
*
p_gamma_
;
const
BetaDataType
*
p_beta_
;
ck
::
index_t
NumRows_
;
ck
::
index_t
EmbeddingDim_
;
ck
::
index_t
IndexLength_
;
AccDataType
epsilon_
;
EmbElementwiseOperation
emb_elementwise_op_
;
size_t
grid_size_
;
};
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
void
*
p_out
,
const
void
*
p_emb_a
,
const
void
*
p_emb_b
,
const
void
*
p_emb_c
,
const
void
*
p_index_a
,
const
void
*
p_index_b
,
const
void
*
p_index_c
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
void
*
p_out
,
const
ck
::
Array
<
EmbType
*
,
NumEmbeddings
>&
p_embs
,
const
ck
::
Array
<
IndexType
*
,
NumEmbeddings
>&
p_indexs
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
ck
::
index_t
NumRows
,
ck
::
index_t
EmbeddingDim
,
ck
::
index_t
IndexLength
,
const
AccDataType
epsilon
)
const
AccDataType
epsilon
,
const
EmbElementwiseOperation
emb_elementwise_op
)
{
return
std
::
make_unique
<
Argument
>
(
reinterpret_cast
<
OutType
*>
(
p_out
),
reinterpret_cast
<
const
EmbType
*>
(
p_emb_a
),
reinterpret_cast
<
const
EmbType
*>
(
p_emb_b
),
reinterpret_cast
<
const
EmbType
*>
(
p_emb_c
),
reinterpret_cast
<
const
IndexType
*>
(
p_index_a
),
reinterpret_cast
<
const
IndexType
*>
(
p_index_b
),
reinterpret_cast
<
const
IndexType
*>
(
p_index_c
),
p_embs
,
p_indexs
,
reinterpret_cast
<
const
GammaDataType
*>
(
p_gamma
),
reinterpret_cast
<
const
BetaDataType
*>
(
p_beta
),
NumRows
,
EmbeddingDim
,
IndexLength
,
epsilon
);
epsilon
,
emb_elementwise_op
);
}
using
GridwiseSparseEmbedding
=
GridwiseSparseEmbedding
3
ForwardLayernorm
<
EmbType
,
GridwiseSparseEmbedding
s
ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
decltype
(
MakeOutputDescriptor
(
1
,
1
)),
EmbElementwiseOperation
,
BlockSize
,
DimClusterSize
,
RowClusterSize
,
DimPerBlock
,
RowPerBlock
,
DimThreadSize
,
RowVectorSize
>
;
RowVectorSize
,
NumEmbeddings
>
;
struct
Invoker
:
public
BaseInvoker
{
...
...
@@ -139,14 +123,16 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
{
auto
out_desc
=
MakeOutputDescriptor
(
arg
.
IndexLength_
,
arg
.
EmbeddingDim_
);
const
auto
kernel_main
=
kernel_sparse_embedding
3
_forward_layernorm
<
GridwiseSparseEmbedding
,
kernel_sparse_embedding
s
_forward_layernorm
<
GridwiseSparseEmbedding
,
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
decltype
(
out_desc
)
>
;
decltype
(
out_desc
),
EmbElementwiseOperation
,
NumEmbeddings
>
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_main
,
...
...
@@ -154,16 +140,13 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
dim3
(
BlockSize
),
0
,
arg
.
p_out_
,
arg
.
p_emb_a_
,
arg
.
p_emb_b_
,
arg
.
p_emb_c_
,
arg
.
p_index_a_
,
arg
.
p_index_b_
,
arg
.
p_index_c_
,
arg
.
p_embs_
,
arg
.
p_indexs_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
out_desc
,
arg
.
epsilon_
);
arg
.
epsilon_
,
arg
.
emb_elementwise_op_
);
return
(
avg_time
);
}
...
...
@@ -177,7 +160,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
static
bool
IsSupportedArgument
(
const
Argument
*
p_arg
)
{
return
(
RowPerBlock
==
p_arg
->
EmbeddingDim_
)
&&
(
p_arg
->
NumRows_
%
DimPerBlock
==
0
)
;
return
(
RowPerBlock
==
p_arg
->
EmbeddingDim_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -195,7 +178,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceSparseEmbedding
3
ForwardLayernorm_"
<<
BlockSize
<<
"_"
<<
str
<<
"DeviceSparseEmbedding
s
ForwardLayernorm_"
<<
BlockSize
<<
"_"
<<
DimClusterSize
<<
"x"
<<
RowClusterSize
<<
"_"
<<
DimPerBlock
<<
"x"
<<
RowPerBlock
<<
"_"
<<
DimThreadSize
<<
"x"
<<
RowVectorSize
;
...
...
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -56,7 +56,8 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
...
...
@@ -938,7 +939,9 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
Prev
1
…
22
23
24
25
26
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment