gaoqiong / composable_kernel_ROCM / Commits

Commit 408534d4 (unverified)
Authored Aug 09, 2024 by Rostyslav Geyyer; committed by GitHub on Aug 09, 2024

    Merge branch 'develop' into lwpck-1815

Parents: a8efb3f0, da214a5a
Changes: 204

Showing 20 changed files with 1271 additions and 159 deletions (+1271 -159)
include/ck/utility/math_v2.hpp                                                               +2   -2
include/ck/utility/reduction_operator.hpp                                                    +19  -3
include/ck/utility/type_convert.hpp                                                          +21  -1
include/ck_tile/core/numeric/bfloat16.hpp                                                    +4   -1
include/ck_tile/core/numeric/float8.hpp                                                      +2   -2
include/ck_tile/core/numeric/half.hpp                                                        +1   -1
include/ck_tile/core/numeric/math.hpp                                                        +1   -1
include/ck_tile/core/tensor/tile_window.hpp                                                  +4   -1
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp                     +3   -1
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp     +59  -57
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp       +12  -12
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp     +12  -12
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp            +13  -13
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp     +50  -49
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp  +1   -0
library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp                  +226 -0
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp         +225 -0
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp                 +81  -3
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp          +457 -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_binary_outelementop_instance.hpp  +78 -0
include/ck/utility/math_v2.hpp

@@ -839,7 +839,7 @@ inline __device__ T rcp(T x)
 template <typename T>
 inline __device__ T exp(T x)
 {
-    return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
+    return ck::type_convert<T>(__ocml_exp_f32(ck::type_convert<float>(x)));
 };

 template <>
@@ -851,7 +851,7 @@ inline __device__ half_t exp<half_t>(half_t x)
 template <>
 inline __device__ float exp<float>(float x)
 {
-    return __expf(x);
+    return __ocml_exp_f32(x);
 };

 template <>
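The exp overloads above (and the similar ones in the ck_tile headers below) all follow the same widen-to-float pattern: convert the argument to float, evaluate the float routine, then convert back to the narrow type. A minimal host-side sketch of that pattern, with std::exp standing in for the device intrinsic named in the diff (the function name below is illustrative only):

    #include <cmath>

    // Illustrative sketch of the convert-through-float pattern: widen, evaluate, narrow.
    template <typename T>
    T exp_via_float(T x)
    {
        return static_cast<T>(std::exp(static_cast<float>(x)));
    }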
include/ck/utility/reduction_operator.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -52,11 +52,19 @@ struct Add
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, int32_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, half_t>::value,
                       "The data type is not supported by the Add accumulator!");

         a = a + b;
     }

+    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+        a        = type_convert<bhalf_t>(a_ + b_);
+    }
 };

 struct SquaredAdd
@@ -104,11 +112,19 @@ struct Mul
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, int32_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, half_t>::value,
                       "The data type is not supported by the Mul accumulator!");

         a = a * b;
     }

+    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+        a        = type_convert<bhalf_t>(a_ * b_);
+    }
 };

 struct Max
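The new bhalf_t overloads accumulate through float instead of in bfloat16 itself. A self-contained sketch of that pattern, using a truncation-based bf16 emulation; the Bf16 struct and helper names below are illustrative, not the library's bhalf_t/type_convert:

    #include <cstdint>
    #include <cstring>

    struct Bf16 { std::uint16_t bits; };   // top 16 bits of an IEEE-754 float

    static float bf16_to_float(Bf16 v)
    {
        std::uint32_t u = static_cast<std::uint32_t>(v.bits) << 16;
        float f;
        std::memcpy(&f, &u, sizeof(f));
        return f;
    }

    static Bf16 float_to_bf16(float f)     // truncation; production code would round
    {
        std::uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        return Bf16{static_cast<std::uint16_t>(u >> 16)};
    }

    // mirrors Add::operator()(bhalf_t&, bhalf_t): widen both operands, add, narrow the result
    static void add_accumulate(Bf16& a, Bf16 b)
    {
        a = float_to_bf16(bf16_to_float(a) + bf16_to_float(b));
    }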
include/ck/utility/type_convert.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/utility/data_type.hpp"
 #include "ck/utility/f8_utils.hpp"
 #include "ck/utility/random_gen.hpp"
+#include "ck/utility/array.hpp"

 namespace ck {

 // Define the common macro for gfx94x models

@@ -500,6 +501,25 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
 #endif
 }

+template <typename Y, typename X, std::size_t NumElems>
+inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
+                                              const std::array<X, NumElems>& x)
+{
+    for(std::size_t i = 0; i < NumElems; i++)
+    {
+        y[i] = type_convert<Y>(x[i]);
+    }
+}
+
+template <typename Y, typename X, index_t NumElems>
+inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
+{
+    for(std::size_t i = 0; i < NumElems; i++)
+    {
+        y[i] = type_convert<Y>(x[i]);
+    }
+}
+
 // Declare a template function for bf16 conversion using RTN
 template <typename Y, typename X>
 __host__ __device__ constexpr Y bf16_convert_rtn(X x);
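A standalone analogue of the new array_convert helper: element-wise conversion between two fixed-size arrays, with a plain static_cast standing in for ck::type_convert (the names below are illustrative, not part of the library):

    #include <array>
    #include <cstddef>

    template <typename Y, typename X, std::size_t NumElems>
    void array_convert_sketch(std::array<Y, NumElems>& y, const std::array<X, NumElems>& x)
    {
        for(std::size_t i = 0; i < NumElems; i++)
            y[i] = static_cast<Y>(x[i]);   // per-element conversion, as in the diff
    }

    // usage:
    //   std::array<float, 4>  src{1.0f, 2.0f, 3.0f, 4.0f};
    //   std::array<double, 4> dst{};
    //   array_convert_sketch(dst, src);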
include/ck_tile/core/numeric/bfloat16.hpp

@@ -331,7 +331,10 @@ bfloat16_t sqrt(bfloat16_t x)
 };

 CK_TILE_DEVICE
-bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); };
+bfloat16_t exp(bfloat16_t x)
+{
+    return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
+};

 CK_TILE_DEVICE
 bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
include/ck_tile/core/numeric/float8.hpp

@@ -835,7 +835,7 @@ CK_TILE_DEVICE
 fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

 CK_TILE_DEVICE
-fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };
+fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__ocml_exp_f32(static_cast<float>(x))); };

 CK_TILE_DEVICE
 fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };

@@ -860,7 +860,7 @@ CK_TILE_DEVICE
 bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

 CK_TILE_DEVICE
-bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };
+bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__ocml_exp_f32(static_cast<float>(x))); };

 CK_TILE_DEVICE
 bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
include/ck_tile/core/numeric/half.hpp

@@ -374,7 +374,7 @@ half_t sqrt(half_t x)
 };

 CK_TILE_DEVICE
-half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); };
+half_t exp(half_t x) { return static_cast<half_t>(__ocml_exp_f32(static_cast<float>(x))); };

 CK_TILE_DEVICE
 half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };
include/ck_tile/core/numeric/math.hpp

@@ -519,7 +519,7 @@ CK_TILE_DEVICE
 double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };

 CK_TILE_DEVICE
-float exp(float x) { return __expf(x); };
+float exp(float x) { return __ocml_exp_f32(x); };

 CK_TILE_HOST
 float exp(float x) { return std::expf(x); }
include/ck_tile/core/tensor/tile_window.hpp

@@ -393,7 +393,10 @@ struct tile_window_with_static_distribution
                 bottom_tensor_thread_coord,
                 bool_constant<oob_conditional_check>{},
                 pre_nop_);
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+            asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
+#endif
             // move thread coordinate
             if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
             {
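The empty asm volatile("") statement added above is a common idiom for keeping the compiler from reordering or merging the surrounding code; the change reuses the existing ROCm 6.1 workaround flag because the same scratch-memory symptom shows up from ROCm 6.2 on. A minimal sketch of the idiom; the "memory" clobber shown here is an illustrative strengthening, not the exact statement in the diff:

    // Emits no instructions, but the compiler may not move memory accesses across it.
    inline void compiler_barrier()
    {
        asm volatile("" ::: "memory");
    }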
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp

@@ -231,7 +231,9 @@ struct BlockFmhaPipelineQRKSVSAsync
         // TODO: we use async Copy for K, which is inline asm
         // a side effect is we have to use inline asm for q as well
         auto q = decltype(load_tile(q_dram_window)){};
-        set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
+        // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
+        // however, q would be cleared in the constructor of static distributed tensor
+        // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
         load_tile_raw(q, q_dram_window);
         __builtin_amdgcn_sched_barrier(0);
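The replacement comments note that the explicit set_tile(q, number<0>{}) clear is no longer needed because the static distributed tensor is already cleared in its constructor. A tiny sketch of why such a clear becomes redundant; the type below is purely illustrative, not the library's tensor:

    #include <array>
    #include <cstddef>

    template <typename T, std::size_t N>
    struct StaticDistributedTensorSketch
    {
        std::array<T, N> data{};   // value-initialized: every element starts as T{}
    };
    // A freshly constructed StaticDistributedTensorSketch<float, 8> already holds zeros,
    // so an explicit clear right after construction does nothing new.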
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
     public:
         Argument(const Tensor<InDataType>& input,
                  Tensor<OutDataType>& output,
-                 std::vector<ck::index_t> filter_spatial_lengths,
-                 std::vector<ck::index_t> conv_filter_strides,
-                 std::vector<ck::index_t> conv_filter_dilations,
-                 std::vector<ck::index_t> input_left_pads,
-                 std::vector<ck::index_t> input_right_pads)
+                 std::vector<ck::long_index_t> filter_spatial_lengths,
+                 std::vector<ck::long_index_t> conv_filter_strides,
+                 std::vector<ck::long_index_t> conv_filter_dilations,
+                 std::vector<ck::long_index_t> input_left_pads,
+                 std::vector<ck::long_index_t> input_right_pads)
             : input_{input},
               output_{output},
              conv_strides_{conv_filter_strides},

@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator
         const Tensor<InDataType>& input_;
         Tensor<OutDataType>& output_;

-        std::vector<index_t> conv_strides_;
-        std::vector<index_t> conv_dilations_;
-        std::vector<index_t> in_left_pads_;
-        std::vector<index_t> in_right_pads_;
+        std::vector<long_index_t> conv_strides_;
+        std::vector<long_index_t> conv_dilations_;
+        std::vector<long_index_t> in_left_pads_;
+        std::vector<long_index_t> in_right_pads_;

-        std::vector<index_t> filter_spatial_lengths_;
-        std::vector<index_t> output_spatial_lengths_;
+        std::vector<long_index_t> filter_spatial_lengths_;
+        std::vector<long_index_t> output_spatial_lengths_;

         private:
         void initOutputSpatialLengths()
         {
             constexpr auto input_offset_to_spatial = 3;
-            for(ck::index_t i = 0; i < NDimSpatial; ++i)
+            for(ck::long_index_t i = 0; i < NDimSpatial; ++i)
             {
                 // XEff = (X - 1) * conv_dilation_w + 1;
                 // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
-                const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
+                const ck::long_index_t x_eff =
+                    (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
                 output_spatial_lengths_.push_back(
                     (output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +

@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator
                 throw std::runtime_error("wrong! inconsistent dimension");
             }

-            const index_t G = arg.output_.GetLengths()[0];
-            const index_t N = arg.output_.GetLengths()[1];
-            const index_t C = arg.output_.GetLengths()[2];
+            const long_index_t G = arg.output_.GetLengths()[0];
+            const long_index_t N = arg.output_.GetLengths()[1];
+            const long_index_t C = arg.output_.GetLengths()[2];

             if constexpr(NDimSpatial == 1)
             {
-                const index_t Wo = arg.output_spatial_lengths_[0];
+                const long_index_t Wo = arg.output_spatial_lengths_[0];
                 auto func = [&](auto g, auto n) {
-                    for(index_t wo = 0; wo < Wo; ++wo)
+                    for(long_index_t wo = 0; wo < Wo; ++wo)
                     {
-                        index_t row    = n * Wo + wo;
-                        index_t column = 0;
+                        long_index_t row    = n * Wo + wo;
+                        long_index_t column = 0;

-                        for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
+                        for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
                         {
                             auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
                                       static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
                                       static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
-                            for(index_t c = 0; c < C; ++c)
+                            for(long_index_t c = 0; c < C; ++c)
                             {
                                 if(wi >= 0 &&
                                    ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])

@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator
             }
             else if constexpr(NDimSpatial == 2)
             {
-                const index_t Ho = arg.output_spatial_lengths_[0];
-                const index_t Wo = arg.output_spatial_lengths_[1];
+                const long_index_t Ho = arg.output_spatial_lengths_[0];
+                const long_index_t Wo = arg.output_spatial_lengths_[1];
                 auto func = [&](auto g, auto n) {
-                    for(index_t ho = 0; ho < Ho; ++ho)
+                    for(long_index_t ho = 0; ho < Ho; ++ho)
                     {
-                        for(index_t wo = 0; wo < Wo; ++wo)
+                        for(long_index_t wo = 0; wo < Wo; ++wo)
                         {
-                            index_t row    = n * Ho * Wo + ho * Wo + wo;
-                            index_t column = 0;
+                            long_index_t row    = n * Ho * Wo + ho * Wo + wo;
+                            long_index_t column = 0;

-                            for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
+                            for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
                             {
                                 auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
                                           static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
                                           static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
-                                for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
+                                for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
                                 {
                                     auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
                                               static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
                                               static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
-                                    for(index_t c = 0; c < C; ++c)
+                                    for(long_index_t c = 0; c < C; ++c)
                                     {
                                         if(hi >= 0 &&

@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator
             }
             else if constexpr(NDimSpatial == 3)
             {
-                const index_t Do = arg.output_spatial_lengths_[0];
-                const index_t Ho = arg.output_spatial_lengths_[1];
-                const index_t Wo = arg.output_spatial_lengths_[2];
+                const long_index_t Do = arg.output_spatial_lengths_[0];
+                const long_index_t Ho = arg.output_spatial_lengths_[1];
+                const long_index_t Wo = arg.output_spatial_lengths_[2];
                 auto func = [&](auto g, auto n) {
-                    for(index_t d_o = 0; d_o < Do; ++d_o)
+                    for(long_index_t d_o = 0; d_o < Do; ++d_o)
                     {
-                        for(index_t ho = 0; ho < Ho; ++ho)
+                        for(long_index_t ho = 0; ho < Ho; ++ho)
                         {
-                            for(index_t wo = 0; wo < Wo; ++wo)
+                            for(long_index_t wo = 0; wo < Wo; ++wo)
                             {
-                                index_t row    = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
-                                index_t column = 0;
+                                long_index_t row    = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
+                                long_index_t column = 0;

-                                for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
+                                for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
                                 {
                                     auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
                                               static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
                                               static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
-                                    for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
+                                    for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
                                     {
                                         auto hi = static_cast<ck::long_index_t>(ho *

@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
                                             static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
                                             static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
-                                        for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
+                                        for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2];
+                                            ++x)
                                         {
                                             auto wi = static_cast<ck::long_index_t>(

@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
                                             static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
                                             static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
-                                            for(index_t c = 0; c < C; ++c)
+                                            for(long_index_t c = 0; c < C; ++c)
                                             {
                                                 if(di >= 0 &&
                                                    ck::type_convert<std::size_t>(di) <

@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator
     bool IsSupportedArgument(const Argument& arg)
     {
-        const ck::index_t G = arg.output_.GetLengths()[0];
-        const ck::index_t N = arg.output_.GetLengths()[1];
-        const ck::index_t C = arg.output_.GetLengths()[2];
+        const ck::long_index_t G = arg.output_.GetLengths()[0];
+        const ck::long_index_t N = arg.output_.GetLengths()[1];
+        const ck::long_index_t C = arg.output_.GetLengths()[2];

-        const index_t NDoHoWo =
-            N * ck::accumulate_n<index_t>(
+        const long_index_t NDoHoWo =
+            N * ck::accumulate_n<long_index_t>(
                     arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
-        const index_t CZYX =
-            C * ck::accumulate_n<index_t>(
+        const long_index_t CZYX =
+            C * ck::accumulate_n<long_index_t>(
                     arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());

         if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&

@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
     static auto MakeArgument(const Tensor<InDataType>& input,
                              Tensor<OutDataType>& output,
-                             std::vector<ck::index_t> filter_spatial_lengths,
-                             std::vector<ck::index_t> conv_filter_strides,
-                             std::vector<ck::index_t> conv_filter_dilations,
-                             std::vector<ck::index_t> input_left_pads,
-                             std::vector<ck::index_t> input_right_pads)
+                             std::vector<ck::long_index_t> filter_spatial_lengths,
+                             std::vector<ck::long_index_t> conv_filter_strides,
+                             std::vector<ck::long_index_t> conv_filter_dilations,
+                             std::vector<ck::long_index_t> input_left_pads,
+                             std::vector<ck::long_index_t> input_right_pads)
     {
         return Argument{input,
                         output,
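The reference operators quote the output-length formula in their comments (XEff = (X - 1) * conv_dilation_w + 1; Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1), and this change widens that arithmetic to 64-bit indices. A small worked sketch of the formula; the function name and the example values are illustrative only:

    #include <cstdint>

    // Wo = (Wi + pad_l + pad_r - XEff) / stride + 1, with XEff = (X - 1) * dilation + 1
    std::int64_t conv_output_length(std::int64_t Wi, std::int64_t X, std::int64_t stride,
                                    std::int64_t dilation, std::int64_t pad_l, std::int64_t pad_r)
    {
        const std::int64_t x_eff = (X - 1) * dilation + 1;
        return (Wi + pad_l + pad_r - x_eff) / stride + 1;
    }
    // e.g. Wi = 224, X = 3, stride = 2, dilation = 1, pad_l = pad_r = 1  ->  Wo = 112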
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp

@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
                  Tensor<InDataType>& input,
                  const Tensor<WeiDataType>& weight,
                  const Tensor<OutDataType>& output,
-                 std::vector<ck::index_t> conv_filter_strides,
-                 std::vector<ck::index_t> conv_filter_dilations,
-                 std::vector<ck::index_t> input_left_pads,
-                 std::vector<ck::index_t> input_right_pads,
+                 std::vector<ck::long_index_t> conv_filter_strides,
+                 std::vector<ck::long_index_t> conv_filter_dilations,
+                 std::vector<ck::long_index_t> input_left_pads,
+                 std::vector<ck::long_index_t> input_right_pads,
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op,

@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
         const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
         const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;

-        std::vector<index_t> conv_strides_;
-        std::vector<index_t> conv_dilations_;
-        std::vector<index_t> in_left_pads_;
-        std::vector<index_t> in_right_pads_;
+        std::vector<long_index_t> conv_strides_;
+        std::vector<long_index_t> conv_dilations_;
+        std::vector<long_index_t> in_left_pads_;
+        std::vector<long_index_t> in_right_pads_;

         InElementwiseOperation in_element_op_;
         WeiElementwiseOperation wei_element_op_;

@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
                              Tensor<InDataType>& input,
                              const Tensor<WeiDataType>& weight,
                              const Tensor<OutDataType>& output,
-                             std::vector<ck::index_t> conv_filter_strides,
-                             std::vector<ck::index_t> conv_filter_dilations,
-                             std::vector<ck::index_t> input_left_pads,
-                             std::vector<ck::index_t> input_right_pads,
+                             std::vector<ck::long_index_t> conv_filter_strides,
+                             std::vector<ck::long_index_t> conv_filter_dilations,
+                             std::vector<ck::long_index_t> input_left_pads,
+                             std::vector<ck::long_index_t> input_right_pads,
                              InElementwiseOperation in_element_op,
                              WeiElementwiseOperation wei_element_op,
                              OutElementwiseOperation out_element_op,
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp

@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
                  const Tensor<InDataType>& in_n_c_hi_wi,
                  Tensor<WeiDataType>& wei_k_c_y_x,
                  const Tensor<OutDataType>& out_n_k_ho_wo,
-                 std::vector<ck::index_t> conv_filter_strides,
-                 std::vector<ck::index_t> conv_filter_dilations,
-                 std::vector<ck::index_t> input_left_pads,
-                 std::vector<ck::index_t> input_right_pads,
+                 std::vector<ck::long_index_t> conv_filter_strides,
+                 std::vector<ck::long_index_t> conv_filter_dilations,
+                 std::vector<ck::long_index_t> input_left_pads,
+                 std::vector<ck::long_index_t> input_right_pads,
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op,

@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
         const std::array<Tensor<InDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
         const std::array<Tensor<WeiDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;

-        std::vector<index_t> conv_strides_;
-        std::vector<index_t> conv_dilations_;
-        std::vector<index_t> in_left_pads_;
-        std::vector<index_t> in_right_pads_;
+        std::vector<long_index_t> conv_strides_;
+        std::vector<long_index_t> conv_dilations_;
+        std::vector<long_index_t> in_left_pads_;
+        std::vector<long_index_t> in_right_pads_;

         InElementwiseOperation in_element_op_;
         WeiElementwiseOperation wei_element_op_;

@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
                              const Tensor<InDataType>& in_n_c_hi_wi,
                              Tensor<WeiDataType>& wei_k_c_y_x,
                              const Tensor<OutDataType>& out_n_k_ho_wo,
-                             std::vector<ck::index_t> conv_filter_strides,
-                             std::vector<ck::index_t> conv_filter_dilations,
-                             std::vector<ck::index_t> input_left_pads,
-                             std::vector<ck::index_t> input_right_pads,
+                             std::vector<ck::long_index_t> conv_filter_strides,
+                             std::vector<ck::long_index_t> conv_filter_dilations,
+                             std::vector<ck::long_index_t> input_left_pads,
+                             std::vector<ck::long_index_t> input_right_pads,
                              InElementwiseOperation in_element_op,
                              WeiElementwiseOperation wei_element_op,
                              OutElementwiseOperation out_element_op,
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -69,10 +69,10 @@ struct ReferenceConvFwd : public device::BaseOperator
                  const Tensor<InDataType>& input,
                  const Tensor<WeiDataType>& weight,
                  Tensor<OutDataType>& output,
-                 std::vector<ck::index_t> conv_filter_strides,
-                 std::vector<ck::index_t> conv_filter_dilations,
-                 std::vector<ck::index_t> input_left_pads,
-                 std::vector<ck::index_t> input_right_pads,
+                 std::vector<ck::long_index_t> conv_filter_strides,
+                 std::vector<ck::long_index_t> conv_filter_dilations,
+                 std::vector<ck::long_index_t> input_left_pads,
+                 std::vector<ck::long_index_t> input_right_pads,
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op,

@@ -103,10 +103,10 @@ struct ReferenceConvFwd : public device::BaseOperator
         const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
         const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;

-        std::vector<index_t> conv_strides_;
-        std::vector<index_t> conv_dilations_;
-        std::vector<index_t> in_left_pads_;
-        std::vector<index_t> in_right_pads_;
+        std::vector<ck::long_index_t> conv_strides_;
+        std::vector<ck::long_index_t> conv_dilations_;
+        std::vector<ck::long_index_t> in_left_pads_;
+        std::vector<ck::long_index_t> in_right_pads_;

         InElementwiseOperation in_element_op_;
         WeiElementwiseOperation wei_element_op_;

@@ -416,10 +416,10 @@ struct ReferenceConvFwd : public device::BaseOperator
                              const Tensor<InDataType>& input,
                              const Tensor<WeiDataType>& weight,
                              Tensor<OutDataType>& output,
-                             std::vector<ck::index_t> conv_filter_strides,
-                             std::vector<ck::index_t> conv_filter_dilations,
-                             std::vector<ck::index_t> input_left_pads,
-                             std::vector<ck::index_t> input_right_pads,
+                             std::vector<ck::long_index_t> conv_filter_strides,
+                             std::vector<ck::long_index_t> conv_filter_dilations,
+                             std::vector<ck::long_index_t> input_left_pads,
+                             std::vector<ck::long_index_t> input_right_pads,
                              InElementwiseOperation in_element_op,
                              WeiElementwiseOperation wei_element_op,
                              OutElementwiseOperation out_element_op,
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp

@@ -40,11 +40,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
     public:
         Argument(const Tensor<InDataType>& input,
                  Tensor<OutDataType>& output,
-                 std::vector<ck::index_t> filter_spatial_lengths,
-                 std::vector<ck::index_t> conv_filter_strides,
-                 std::vector<ck::index_t> conv_filter_dilations,
-                 std::vector<ck::index_t> input_left_pads,
-                 std::vector<ck::index_t> input_right_pads)
+                 std::vector<ck::long_index_t> filter_spatial_lengths,
+                 std::vector<ck::long_index_t> conv_filter_strides,
+                 std::vector<ck::long_index_t> conv_filter_dilations,
+                 std::vector<ck::long_index_t> input_left_pads,
+                 std::vector<ck::long_index_t> input_right_pads)
             : input_{input},
              output_{output},
              conv_strides_{conv_filter_strides},

@@ -59,13 +59,13 @@ struct ReferenceImageToColumn : public device::BaseOperator
         const Tensor<InDataType>& input_;
         Tensor<OutDataType>& output_;

-        std::vector<index_t> conv_strides_;
-        std::vector<index_t> conv_dilations_;
-        std::vector<index_t> in_left_pads_;
-        std::vector<index_t> in_right_pads_;
+        std::vector<long_index_t> conv_strides_;
+        std::vector<long_index_t> conv_dilations_;
+        std::vector<long_index_t> in_left_pads_;
+        std::vector<long_index_t> in_right_pads_;

-        std::vector<index_t> filter_spatial_lengths_;
-        std::vector<index_t> output_spatial_lengths_;
+        std::vector<long_index_t> filter_spatial_lengths_;
+        std::vector<long_index_t> output_spatial_lengths_;

         private:
         void initOutputSpatialLengths()

@@ -76,7 +76,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
             {
                 // XEff = (X - 1) * conv_dilation_w + 1;
                 // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
-                const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
+                const ck::long_index_t x_eff =
+                    (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
                 output_spatial_lengths_.push_back(
                     (input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +

@@ -99,24 +100,24 @@ struct ReferenceImageToColumn : public device::BaseOperator
                 throw std::runtime_error("wrong! inconsistent dimension");
             }

-            const index_t G = arg.input_.GetLengths()[0];
-            const index_t N = arg.input_.GetLengths()[1];
-            const index_t C = arg.input_.GetLengths()[2];
+            const long_index_t G = arg.input_.GetLengths()[0];
+            const long_index_t N = arg.input_.GetLengths()[1];
+            const long_index_t C = arg.input_.GetLengths()[2];

             if constexpr(NDimSpatial == 1)
             {
-                const index_t Wo = arg.output_spatial_lengths_[0];
+                const long_index_t Wo = arg.output_spatial_lengths_[0];
                 auto func = [&](auto g, auto n, auto wo) {
-                    index_t row    = n * Wo + wo;
-                    index_t column = 0;
+                    long_index_t row    = n * Wo + wo;
+                    long_index_t column = 0;

-                    for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
+                    for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
                     {
                         auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
                                   static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
                                   static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
-                        for(index_t c = 0; c < C; ++c)
+                        for(long_index_t c = 0; c < C; ++c)
                         {
                             if(wi >= 0 &&
                                ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])

@@ -135,26 +136,26 @@ struct ReferenceImageToColumn : public device::BaseOperator
             }
             else if constexpr(NDimSpatial == 2)
             {
-                const index_t Ho = arg.output_spatial_lengths_[0];
-                const index_t Wo = arg.output_spatial_lengths_[1];
+                const long_index_t Ho = arg.output_spatial_lengths_[0];
+                const long_index_t Wo = arg.output_spatial_lengths_[1];

                 auto func = [&](auto g, auto n, auto ho, auto wo) {
-                    index_t row    = n * Ho * Wo + ho * Wo + wo;
-                    index_t column = 0;
+                    long_index_t row    = n * Ho * Wo + ho * Wo + wo;
+                    long_index_t column = 0;

-                    for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
+                    for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
                     {
                         auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
                                   static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
                                   static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
-                        for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
+                        for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
                         {
                             auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
                                       static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
                                       static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
-                            for(index_t c = 0; c < C; ++c)
+                            for(long_index_t c = 0; c < C; ++c)
                             {
                                 if(hi >= 0 &&

@@ -178,31 +179,31 @@ struct ReferenceImageToColumn : public device::BaseOperator
             }
             else if constexpr(NDimSpatial == 3)
            {
-                const index_t Do = arg.output_spatial_lengths_[0];
-                const index_t Ho = arg.output_spatial_lengths_[1];
-                const index_t Wo = arg.output_spatial_lengths_[2];
+                const long_index_t Do = arg.output_spatial_lengths_[0];
+                const long_index_t Ho = arg.output_spatial_lengths_[1];
+                const long_index_t Wo = arg.output_spatial_lengths_[2];

                 auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
-                    index_t row    = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
-                    index_t column = 0;
+                    long_index_t row    = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
+                    long_index_t column = 0;

-                    for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
+                    for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
                     {
                         auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
                                   static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
                                   static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
-                        for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
+                        for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
                         {
                             auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
                                       static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
                                       static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
-                            for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
+                            for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
                             {
                                 auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
                                           static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
                                           static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
-                                for(index_t c = 0; c < C; ++c)
+                                for(long_index_t c = 0; c < C; ++c)
                                 {
                                     if(di >= 0 &&
                                        ck::type_convert<std::size_t>(di) <

@@ -259,15 +260,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
     bool IsSupportedArgument(const Argument& arg)
     {
-        const ck::index_t G = arg.input_.GetLengths()[0];
-        const ck::index_t N = arg.input_.GetLengths()[1];
-        const ck::index_t C = arg.input_.GetLengths()[2];
+        const ck::long_index_t G = arg.input_.GetLengths()[0];
+        const ck::long_index_t N = arg.input_.GetLengths()[1];
+        const ck::long_index_t C = arg.input_.GetLengths()[2];

-        const index_t NDoHoWo =
-            N * ck::accumulate_n<index_t>(
+        const long_index_t NDoHoWo =
+            N * ck::accumulate_n<long_index_t>(
                     arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
-        const index_t CZYX =
-            C * ck::accumulate_n<index_t>(
+        const long_index_t CZYX =
+            C * ck::accumulate_n<long_index_t>(
                     arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());

         if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&

@@ -291,11 +292,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
     static auto MakeArgument(const Tensor<InDataType>& input,
                              Tensor<OutDataType>& output,
-                             std::vector<ck::index_t> filter_spatial_lengths,
-                             std::vector<ck::index_t> conv_filter_strides,
-                             std::vector<ck::index_t> conv_filter_dilations,
-                             std::vector<ck::index_t> input_left_pads,
-                             std::vector<ck::index_t> input_right_pads)
+                             std::vector<ck::long_index_t> filter_spatial_lengths,
+                             std::vector<ck::long_index_t> conv_filter_strides,
+                             std::vector<ck::long_index_t> conv_filter_dilations,
+                             std::vector<ck::long_index_t> input_left_pads,
+                             std::vector<ck::long_index_t> input_right_pads)
     {
         return Argument{input,
                         output,
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp

@@ -108,6 +108,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
 using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
 using AddMultiply      = ck::tensor_operation::element_wise::AddMultiply;
 using MultiplyAdd      = ck::tensor_operation::element_wise::MultiplyAdd;
 using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
 using ScaleAdd         = ck::tensor_operation::element_wise::ScaleAdd;
 using Gelu             = ck::tensor_operation::element_wise::Gelu;
 using Swish            = ck::tensor_operation::element_wise::Swish;
library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

template <typename A0DataType,
          typename A1DataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<
        ALayout, BLayout, Tuple<>, CLayout, A0DataType, A1DataType, B0DataType, B1DataType,
        Tuple<>, CDataType, 128, 128, 128,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGemmMultipleD_ABScale<ALayout, BLayout, Tuple<>, CLayout,
                                                 A0DataType, A1DataType, B0DataType, B1DataType,
                                                 Tuple<>, CDataType, 128, 128, 128,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
        if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(op_ptrs);
            }
        }
#endif
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
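The new instance headers all follow the same factory pattern: free add_device_* functions append concrete kernel instances to a vector of unique_ptr to a common device-operation base, and DeviceOperationInstanceFactory::GetInstances() calls the registration functions matching the requested data types and layouts. A generic, self-contained sketch of that pattern; the types and function names below are illustrative, not the library's:

    #include <memory>
    #include <vector>

    struct OpBase { virtual ~OpBase() = default; };
    struct OpVariantA : OpBase {};
    struct OpVariantB : OpBase {};

    // each registration function appends one family of instances
    void add_variant_a_instances(std::vector<std::unique_ptr<OpBase>>& v) { v.push_back(std::make_unique<OpVariantA>()); }
    void add_variant_b_instances(std::vector<std::unique_ptr<OpBase>>& v) { v.push_back(std::make_unique<OpVariantB>()); }

    std::vector<std::unique_ptr<OpBase>> get_instances()
    {
        std::vector<std::unique_ptr<OpBase>> op_ptrs;
        add_variant_a_instances(op_ptrs);
        add_variant_b_instances(op_ptrs);
        return op_ptrs;   // the caller can then benchmark or select among the instances
    }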
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
#endif

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmMultipleD<
        ALayout, BLayout, Tuple<Row, Col>, CLayout, ADataType, BDataType, Tuple<F32, F32>, CDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::MultiplyMultiply>>
{
    using DeviceOp = DeviceGemmMultipleD<ALayout, BLayout, Tuple<Row, Col>, CLayout,
                                         ADataType, BDataType, Tuple<F32, F32>, CDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::MultiplyMultiply>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
        if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
        }
#endif
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
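These instances pair an f8 GEMM with a MultiplyMultiply epilogue that takes two extra F32 operand tensors (the Tuple<F32, F32> Ds). A per-element sketch of what such an epilogue is assumed to compute, namely scaling the accumulator by the two auxiliary operands; this illustrates assumed semantics and is not the library's definition:

    // assumed elementwise epilogue: out = acc * d0 * d1
    struct MultiplyMultiplySketch
    {
        template <typename E, typename C, typename D0, typename D1>
        void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
        {
            e = static_cast<E>(c * d0 * d1);
        }
    };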
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -315,7 +315,7 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
     DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
 #endif

-#ifdef CK_ENABLE_FP16
+#ifdef CK_ENABLE_BF16
 void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
     std::vector<std::unique_ptr<
         DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&

@@ -416,6 +416,57 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
     DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
 #endif

+#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+
 template <typename ADataType,
           typename BDataType,

@@ -596,7 +647,7 @@ struct DeviceOperationInstanceFactory<
     }
 }
 #endif

-#ifdef CK_ENABLE_FP16
+#ifdef CK_ENABLE_BF16
     if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
                  is_same_v<CDataType, bhalf_t>)
     {

@@ -653,6 +704,33 @@ struct DeviceOperationInstanceFactory<
                 op_ptrs);
             }
         }
 #endif

+#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
+        if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
+                     is_same_v<CDataType, bhalf_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                         is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
+                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
+                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
+                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(op_ptrs);
+                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
+                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
+                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(op_ptrs);
+                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
+                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
+                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(op_ptrs);
+            }
+        }
+#endif

         return op_ptrs;
     }
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp
0 → 100644
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
DsLayout
=
ck
::
Tuple
<>
;
using
DsDataType
=
ck
::
Tuple
<>
;
#ifdef CK_ENABLE_FP16
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmV2R1<ALayout,
                                                 BLayout,
                                                 DsLayout,
                                                 CLayout,
                                                 ADataType,
                                                 BDataType,
                                                 DsDataType,
                                                 CDataType,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGemmV2R1<ALayout,
                                    BLayout,
                                    DsLayout,
                                    CLayout,
                                    ADataType,
                                    BDataType,
                                    DsDataType,
                                    CDataType,
                                    ck::tensor_operation::element_wise::PassThrough,
                                    ck::tensor_operation::element_wise::PassThrough,
                                    ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
        }
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
        }
#endif
#ifdef CK_ENABLE_BF16
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
        }
#endif
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
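For context, the factory specialization above is how client code discovers these universal-reduce GEMM instances at runtime. The sketch below is illustrative only and not part of this commit: the empty ck::Tuple<> used for DsLayout/DsDataType and the all-row-major F16 configuration are assumptions, and GetInstances() simply returns an empty vector if the library was not built with instances for that exact combination.

// Minimal consumer sketch (assumptions noted above); enumerate the registered
// universal-reduce GEMM instances and print each instance's configuration string.
#include <iostream>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp"

int main()
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using F16 = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // DsLayout/DsDataType are assumed empty tuples here; real callers must match
    // the Ds configuration the instance library was built with.
    using DeviceOp = ck::tensor_operation::device::
        DeviceGemmV2R1<Row, Row, ck::Tuple<>, Row, F16, F16, ck::Tuple<>, F16,
                       PassThrough, PassThrough, PassThrough>;

    // One unique_ptr per registered kernel configuration.
    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances\n";
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n'; // human-readable tuning parameters

    return 0;
}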
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_binary_outelementop_instance.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F32 = float;
using F8  = ck::f8_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;

static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;

static constexpr auto ConvFwdOddC =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;

static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec,
          typename OutElementOp>
using device_grouped_conv_fwd_xdl_binary_outelementop_f8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef CK_ENABLE_FP8
// generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1,  64,  64,  64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
        // instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1,  64,  64,  32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128,  64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128,  64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1,  64,  64,  64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128,  64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256,  64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128,  32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128,  32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1,  64,  64,  32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1,  64,  32,  64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>
#endif
    // clang-format on
    >;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
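For context, a tuple alias like the one above is normally consumed by a small registration translation unit that appends every configuration to an instance vector. The sketch below is illustrative and not part of this commit: the 2-D NHWGC/GKYXC/NHWGK layouts, the Bilinear output element-wise op, the Tuple<NHWGK> D-layout, and the function name are all assumptions; a real instance file would also use the concrete DeviceGroupedConvFwdMultipleABD interface type rather than BaseOperator.

// Hypothetical registration unit (assumptions noted above).
#include <memory>
#include <vector>

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_binary_outelementop_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv2d_fwd_xdl_bilinear_nhwgc_gkyxc_nhwgk_f8_instances( // hypothetical name
    std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    using Bilinear = ck::tensor_operation::element_wise::Bilinear;

    // Append every fp8 kernel configuration from the tuple alias defined in the header
    // above; with CK_ENABLE_FP8 undefined the tuple is empty and nothing is added.
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_binary_outelementop_f8_instances<2,
                                                                     NHWGC,
                                                                     GKYXC,
                                                                     ck::Tuple<NHWGK>,
                                                                     NHWGK,
                                                                     ConvFwdDefault,
                                                                     Bilinear>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck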