Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
10e8be48
Unverified
Commit
10e8be48
authored
Oct 01, 2024
by
M.Emin Ozturk
Committed by
GitHub
Oct 01, 2024
Browse files
Merge branch 'develop' into gemm_bf16_sk_muozturk
parents
b416c877
11b7a4db
Changes
48
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1090 additions
and
156 deletions
+1090
-156
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+29
-29
include/ck_tile/core/container/array.hpp
include/ck_tile/core/container/array.hpp
+12
-1
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+2
-0
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp
...k_tile/host/convolution_host_tensor_descriptor_helper.hpp
+266
-0
include/ck_tile/host/convolution_parameter.hpp
include/ck_tile/host/convolution_parameter.hpp
+283
-0
include/ck_tile/host/host_tensor.hpp
include/ck_tile/host/host_tensor.hpp
+14
-1
include/ck_tile/host/reference/reference_im2col.hpp
include/ck_tile/host/reference/reference_im2col.hpp
+117
-45
include/ck_tile/ops/fmha/block/block_masking.hpp
include/ck_tile/ops/fmha/block/block_masking.hpp
+2
-2
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+22
-29
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
...fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
+8
-9
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+21
-23
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
...ile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
+2
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+6
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+5
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+5
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+4
-5
include/ck_tile/ops/image_to_column.hpp
include/ck_tile/ops/image_to_column.hpp
+9
-0
include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp
...ile/ops/image_to_column/kernel/image_to_column_kernel.hpp
+224
-0
include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp
...mage_to_column/pipeline/block_image_to_column_problem.hpp
+27
-0
include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp
...s/image_to_column/pipeline/tile_image_to_column_shape.hpp
+32
-0
No files found.
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
10e8be48
...
...
@@ -651,97 +651,97 @@ struct MfmaSelector
static
constexpr
auto
GetMfma
();
template
<
>
static
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f64_16x16x4f64
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
64
,
64
>
()
constexpr
auto
GetMfma
<
float
,
64
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
32
,
64
>
()
constexpr
auto
GetMfma
<
float
,
32
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
16
,
64
>
()
constexpr
auto
GetMfma
<
float
,
16
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
8
,
64
>
()
constexpr
auto
GetMfma
<
float
,
8
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
4
,
64
>
()
constexpr
auto
GetMfma
<
float
,
4
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
32
,
32
>
()
constexpr
auto
GetMfma
<
float
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x2xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
16
,
16
>
()
constexpr
auto
GetMfma
<
float
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x4xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
64
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
64
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
32
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
32
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
half_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x8f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x16f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
16
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
8
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
8
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
4
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
4
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bhalf_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
bhalf_t
,
32
,
32
>
()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return
MfmaInstr
::
mfma_f32_32x32x8bf16_1k
;
...
...
@@ -751,7 +751,7 @@ struct MfmaSelector
}
template
<
>
static
constexpr
auto
GetMfma
<
bhalf_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
bhalf_t
,
16
,
16
>
()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return
MfmaInstr
::
mfma_f32_16x16x16bf16_1k
;
...
...
@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940)
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_i32_32x32x16i8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_i32_16x16x32i8
;
}
#else
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_i32_32x32x8i8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_i32_16x16x16i8
;
}
#endif
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16bf8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
,
f8_t
>
()
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
,
f8_t
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
}
...
...
include/ck_tile/core/container/array.hpp
View file @
10e8be48
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
...
...
@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
return
!
(
a
==
b
);
}
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
std
::
vector
<
X
>&
x
)
{
array
<
T
,
N
>
arr
;
static_for
<
0
,
N
,
1
>
{}([
&
x
,
&
arr
](
auto
i
)
{
arr
(
i
)
=
x
[
i
];
});
return
arr
;
}
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
X
&
x
)
{
...
...
include/ck_tile/host.hpp
View file @
10e8be48
...
...
@@ -5,6 +5,8 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"
...
...
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp
0 → 100644
View file @
10e8be48
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace
ck_tile
{
namespace
conv
{
namespace
detail
{
template
<
typename
OldLayout
>
CK_TILE_HOST
std
::
vector
<
std
::
size_t
>
get_layout_transpose_gnchw_to_old
()
{
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKW
>
)
{
return
{
0
,
1
,
2
,
3
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCHW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCYX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKHW
>
)
{
return
{
0
,
1
,
2
,
3
,
4
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCDHW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCZYX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKDHW
>
)
{
return
{
0
,
1
,
2
,
3
,
4
,
5
};
}
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWK
>
)
{
return
{
0
,
1
,
3
,
2
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKYXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWK
>
)
{
return
{
0
,
1
,
4
,
2
,
3
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKZYXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWK
>
)
{
return
{
0
,
1
,
5
,
2
,
3
,
4
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGK
>
)
{
return
{
2
,
0
,
3
,
1
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGK
>
)
{
return
{
3
,
0
,
4
,
1
,
2
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGK
>
)
{
return
{
4
,
0
,
5
,
1
,
2
,
3
};
}
else
{
printf
(
"%s
\n
"
,
__func__
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
}
}
// namespace detail
// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
// regardless of physical layout
template
<
typename
InLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCW
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCHW
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCDHW
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
InLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
InLayout
>
());
}
// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
// regardless of physical layout
template
<
typename
WeiLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXC
>
)
{
if
(
param
.
G_
!=
1
)
{
throw
std
::
runtime_error
(
"wrong! G != 1"
);
}
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCX
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCYX
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCZYX
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKYXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKZYXC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXGC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXGC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXGC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
WeiLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
WeiLayout
>
());
}
// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
// regardless of physical layout
template
<
typename
OutLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKW
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKHW
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKDHW
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
// separate from legacy code above
else
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWK
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGK
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
OutLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
OutLayout
>
());
}
}
// namespace conv
}
// namespace ck_tile
include/ck_tile/host/convolution_parameter.hpp
0 → 100644
View file @
10e8be48
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>
namespace
ck_tile
{
namespace
conv
{
struct
ConvParam
{
ConvParam
();
ConvParam
(
ck_tile
::
index_t
n_dim
,
ck_tile
::
index_t
group_count
,
ck_tile
::
index_t
n_batch
,
ck_tile
::
index_t
n_out_channels
,
ck_tile
::
index_t
n_in_channels
,
const
std
::
vector
<
ck_tile
::
index_t
>&
filters_len
,
const
std
::
vector
<
ck_tile
::
index_t
>&
input_len
,
const
std
::
vector
<
ck_tile
::
index_t
>&
strides
,
const
std
::
vector
<
ck_tile
::
index_t
>&
dilations
,
const
std
::
vector
<
ck_tile
::
index_t
>&
left_pads
,
const
std
::
vector
<
ck_tile
::
index_t
>&
right_pads
)
:
num_dim_spatial_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_dim
)),
G_
(
static_cast
<
ck_tile
::
long_index_t
>
(
group_count
)),
N_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_batch
)),
K_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_out_channels
)),
C_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_in_channels
)),
filter_spatial_lengths_
(
num_dim_spatial_
),
input_spatial_lengths_
(
num_dim_spatial_
),
output_spatial_lengths_
(
num_dim_spatial_
),
conv_filter_strides_
(
num_dim_spatial_
),
conv_filter_dilations_
(
num_dim_spatial_
),
input_left_pads_
(
num_dim_spatial_
),
input_right_pads_
(
num_dim_spatial_
)
{
if
(
static_cast
<
ck_tile
::
index_t
>
(
filter_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_strides_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_dilations_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_left_pads_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_right_pads_
.
size
())
!=
num_dim_spatial_
)
{
throw
(
std
::
runtime_error
(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"
));
}
for
(
ck_tile
::
index_t
i
=
0
;
i
<
num_dim_spatial_
;
++
i
)
{
filter_spatial_lengths_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
filters_len
[
i
]);
input_spatial_lengths_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
input_len
[
i
]);
conv_filter_strides_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
strides
[
i
]);
conv_filter_dilations_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
dilations
[
i
]);
input_left_pads_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
left_pads
[
i
]);
input_right_pads_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
right_pads
[
i
]);
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck_tile
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
output_spatial_lengths_
[
i
]
=
(
input_spatial_lengths_
[
i
]
+
input_left_pads_
[
i
]
+
input_right_pads_
[
i
]
-
x_eff
)
/
conv_filter_strides_
[
i
]
+
1
;
}
}
ConvParam
(
ck_tile
::
long_index_t
n_dim
,
ck_tile
::
long_index_t
group_count
,
ck_tile
::
long_index_t
n_batch
,
ck_tile
::
long_index_t
n_out_channels
,
ck_tile
::
long_index_t
n_in_channels
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
filters_len
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
input_len
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
strides
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
dilations
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
left_pads
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
right_pads
)
:
num_dim_spatial_
(
n_dim
),
G_
(
group_count
),
N_
(
n_batch
),
K_
(
n_out_channels
),
C_
(
n_in_channels
),
filter_spatial_lengths_
(
filters_len
),
input_spatial_lengths_
(
input_len
),
output_spatial_lengths_
(
num_dim_spatial_
),
conv_filter_strides_
(
strides
),
conv_filter_dilations_
(
dilations
),
input_left_pads_
(
left_pads
),
input_right_pads_
(
right_pads
)
{
if
(
static_cast
<
ck_tile
::
index_t
>
(
filter_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_strides_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_dilations_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_left_pads_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_right_pads_
.
size
())
!=
num_dim_spatial_
)
{
throw
(
std
::
runtime_error
(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"
));
}
for
(
ck_tile
::
index_t
i
=
0
;
i
<
num_dim_spatial_
;
++
i
)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck_tile
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
output_spatial_lengths_
[
i
]
=
(
input_spatial_lengths_
[
i
]
+
input_left_pads_
[
i
]
+
input_right_pads_
[
i
]
-
x_eff
)
/
conv_filter_strides_
[
i
]
+
1
;
}
}
ck_tile
::
long_index_t
num_dim_spatial_
;
ck_tile
::
long_index_t
G_
;
ck_tile
::
long_index_t
N_
;
ck_tile
::
long_index_t
K_
;
ck_tile
::
long_index_t
C_
;
std
::
vector
<
ck_tile
::
long_index_t
>
filter_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
output_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_strides_
;
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_dilations_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_left_pads_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_right_pads_
;
std
::
vector
<
ck_tile
::
long_index_t
>
GetOutputSpatialLengths
()
const
{
return
output_spatial_lengths_
;
}
std
::
size_t
GetFlops
()
const
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return
static_cast
<
std
::
size_t
>
(
2
)
*
G_
*
N_
*
K_
*
C_
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
next
(
std
::
begin
(
output_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
())
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
next
(
std
::
begin
(
filter_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
());
}
template
<
typename
InDataType
>
std
::
size_t
GetInputByte
()
const
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return
sizeof
(
InDataType
)
*
(
G_
*
N_
*
C_
*
std
::
accumulate
(
std
::
begin
(
input_spatial_lengths_
),
std
::
next
(
std
::
begin
(
input_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
()));
}
template
<
typename
WeiDataType
>
std
::
size_t
GetWeightByte
()
const
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return
sizeof
(
WeiDataType
)
*
(
G_
*
K_
*
C_
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
next
(
std
::
begin
(
filter_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
()));
}
template
<
typename
OutDataType
>
std
::
size_t
GetOutputByte
()
const
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return
sizeof
(
OutDataType
)
*
(
G_
*
N_
*
K_
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
end
(
output_spatial_lengths_
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
()));
}
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
>
std
::
size_t
GetByte
()
const
{
return
GetInputByte
<
InDataType
>
()
+
GetWeightByte
<
WeiDataType
>
()
+
GetOutputByte
<
OutDataType
>
();
}
};
ConvParam
::
ConvParam
()
:
ConvParam
::
ConvParam
(
2
,
1
,
128
,
256
,
192
,
{
3
,
3
},
{
71
,
71
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
})
{
}
CK_TILE_HOST
std
::
string
get_conv_param_parser_helper_msg
()
{
std
::
string
msg
;
msg
+=
"Following arguments (depending on number of spatial dims):
\n
"
" Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)
\n
"
" G, N, K, C,
\n
"
" <filter spatial dimensions>, (ie Y, X for 2D)
\n
"
" <input image spatial dimensions>, (ie Hi, Wi for 2D)
\n
"
" <strides>, (ie Sy, Sx for 2D)
\n
"
" <dilations>, (ie Dy, Dx for 2D)
\n
"
" <left padding>, (ie LeftPy, LeftPx for 2D)
\n
"
" <right padding>, (ie RightPy, RightPx for 2D)
\n
"
;
return
msg
;
}
CK_TILE_HOST
ck_tile
::
conv
::
ConvParam
parse_conv_param
(
int
num_dim_spatial
,
int
arg_idx
,
char
*
const
argv
[])
{
const
ck_tile
::
long_index_t
G
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
N
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
K
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
C
=
std
::
stol
(
argv
[
arg_idx
++
]);
std
::
vector
<
ck_tile
::
long_index_t
>
filter_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_strides
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_dilations
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_left_pads
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_right_pads
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
filter_spatial_lengths
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_spatial_lengths
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
conv_filter_strides
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
conv_filter_dilations
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_left_pads
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_right_pads
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
return
ck_tile
::
conv
::
ConvParam
{
num_dim_spatial
,
G
,
N
,
K
,
C
,
filter_spatial_lengths
,
input_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
}
}
// namespace conv
}
// namespace ck_tile
include/ck_tile/host/host_tensor.hpp
View file @
10e8be48
...
...
@@ -176,7 +176,20 @@ struct HostTensorDescriptor
return
std
::
inner_product
(
iss
.
begin
(),
iss
.
end
(),
mStrides
.
begin
(),
std
::
size_t
{
0
});
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
HostTensorDescriptor
&
desc
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
HostTensorDescriptor
&
desc
)
{
os
<<
"dim "
<<
desc
.
get_num_of_dimension
()
<<
", "
;
os
<<
"lengths {"
;
LogRange
(
os
,
desc
.
get_lengths
(),
", "
);
os
<<
"}, "
;
os
<<
"strides {"
;
LogRange
(
os
,
desc
.
get_strides
(),
", "
);
os
<<
"}"
;
return
os
;
}
private:
std
::
vector
<
std
::
size_t
>
mLens
;
...
...
include/ck_tile/host/reference/reference_im2col.hpp
View file @
10e8be48
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -9,53 +9,125 @@
namespace
ck_tile
{
template
<
typename
T
>
CK_TILE_HOST
void
reference_im2col
(
HostTensor
<
T
>&
in_mtx_host_ref
,
const
HostTensor
<
T
>&
in_host
,
int
/*N*/
,
int
/*K*/
,
int
C
,
int
/*Y*/
,
int
X
,
int
Hi
,
int
Wi
,
int
Ho
,
int
Wo
,
int
ConvStrideH
,
int
ConvStrideW
,
int
ConvDilationH
,
int
ConvDilationW
,
int
InLeftPadH
,
int
InLeftPadW
,
int
/*InRightPadH*/
,
int
/*InRightPadW*/
)
template
<
typename
InDataType
,
typename
OutDataType
,
index_t
NDimSpatial
>
CK_TILE_HOST
void
reference_im2col
(
const
HostTensor
<
InDataType
>&
in_host
,
HostTensor
<
OutDataType
>&
out_host
,
const
ck_tile
::
conv
::
ConvParam
&
conv_params
)
{
int
GemmM
=
in_mtx_host_ref
.
get_lengths
()[
0
];
int
GemmK
=
in_mtx_host_ref
.
get_lengths
()[
1
];
const
long_index_t
G
=
in_host
.
get_lengths
()[
0
];
const
long_index_t
N
=
in_host
.
get_lengths
()[
1
];
const
long_index_t
C
=
in_host
.
get_lengths
()[
2
];
for
(
int
gemm_m
=
0
;
gemm_m
<
GemmM
;
++
gemm_m
)
if
constexpr
(
NDimSpatial
==
1
)
{
int
mtmp
=
gemm_m
;
int
n
=
mtmp
/
(
Ho
*
Wo
);
mtmp
-=
n
*
Ho
*
Wo
;
int
ho
=
mtmp
/
Wo
;
int
wo
=
mtmp
-
ho
*
Wo
;
for
(
int
gemm_k
=
0
;
gemm_k
<
GemmK
;
++
gemm_k
)
{
int
ktmp
=
gemm_k
;
int
y
=
ktmp
/
(
X
*
C
);
ktmp
-=
y
*
X
*
C
;
int
x
=
ktmp
/
C
;
int
c
=
ktmp
-
x
*
C
;
int
hi
=
y
*
ConvDilationH
+
ho
*
ConvStrideH
-
InLeftPadH
;
int
wi
=
x
*
ConvDilationW
+
wo
*
ConvStrideW
-
InLeftPadW
;
bool
inbound
=
(
hi
>=
0
&&
hi
<
Hi
&&
wi
>=
0
&&
wi
<
Wi
);
in_mtx_host_ref
(
gemm_m
,
gemm_k
)
=
inbound
?
in_host
(
n
,
hi
,
wi
,
c
)
:
0
;
}
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
0
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
wo
)
{
long_index_t
row
=
n
*
Wo
+
wo
;
long_index_t
column
=
0
;
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
x
)
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
0
])
+
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
0
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
3
])
{
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
}
column
++
;
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
long_index_t
Ho
=
conv_params
.
output_spatial_lengths_
[
0
];
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
1
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
ho
,
auto
wo
)
{
long_index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_index_t
column
=
0
;
for
(
long_index_t
y
=
0
;
y
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
y
)
{
auto
hi
=
static_cast
<
long_index_t
>
(
ho
*
conv_params
.
conv_filter_strides_
[
0
])
+
static_cast
<
long_index_t
>
(
y
*
conv_params
.
conv_filter_dilations_
[
0
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
1
];
++
x
)
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
1
])
+
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
1
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
1
]);
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
hi
>=
0
&&
type_convert
<
std
::
size_t
>
(
hi
)
<
in_host
.
get_lengths
()[
3
]
&&
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
4
])
{
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
hi
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
}
column
++
;
}
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Ho
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
long_index_t
Do
=
conv_params
.
output_spatial_lengths_
[
0
];
const
long_index_t
Ho
=
conv_params
.
output_spatial_lengths_
[
1
];
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
2
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
long_index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_index_t
column
=
0
;
for
(
long_index_t
z
=
0
;
z
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
z
)
{
auto
di
=
static_cast
<
long_index_t
>
(
d_o
*
conv_params
.
conv_filter_strides_
[
0
])
+
static_cast
<
long_index_t
>
(
z
*
conv_params
.
conv_filter_dilations_
[
0
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
for
(
long_index_t
y
=
0
;
y
<
conv_params
.
filter_spatial_lengths_
[
1
];
++
y
)
{
auto
hi
=
static_cast
<
long_index_t
>
(
ho
*
conv_params
.
conv_filter_strides_
[
1
])
+
static_cast
<
long_index_t
>
(
y
*
conv_params
.
conv_filter_dilations_
[
1
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
1
]);
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
2
];
++
x
)
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
2
])
+
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
2
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
2
]);
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
di
>=
0
&&
type_convert
<
std
::
size_t
>
(
di
)
<
in_host
.
get_lengths
()[
3
]
&&
hi
>=
0
&&
type_convert
<
std
::
size_t
>
(
hi
)
<
in_host
.
get_lengths
()[
4
]
&&
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
5
])
{
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
di
,
hi
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
}
column
++
;
}
}
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Do
,
Ho
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
}
}
// namespace ck_tile
include/ck_tile/ops/fmha/block/block_masking.hpp
View file @
10e8be48
...
...
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{
auto
[
origin_start
,
origin_end
]
=
GetTileRangeAlongX
(
i_y
,
height
,
width
);
const
index_t
x_per_split
=
ck_tile
::
max
(
1
,
x_total
/
num_splits
);
const
index_t
x_per_split
=
ck_tile
::
max
(
1
,
integer_divide_ceil
(
x_total
,
num_splits
)
)
;
const
index_t
split_start
=
x_per_split
*
i_split
;
const
index_t
split_end
=
(
i_split
==
num_splits
-
1
?
x_total
:
split_start
+
x_per_split
)
;
const
index_t
split_end
=
split_start
+
x_per_split
;
return
ck_tile
::
make_tuple
(
ck_tile
::
max
(
origin_start
,
split_start
),
ck_tile
::
min
(
origin_end
,
split_end
));
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
10e8be48
...
...
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
o_ptr
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
max_seqlen_q
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
num_splits
;
...
...
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
};
...
...
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std
::
conditional_t
<
kStoreLSE
,
CommonLSEKargs
,
EmptyKargs
<
0
>>
,
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
1
>>
{
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
batch_stride_o
;
};
struct
GroupModeKargs
...
...
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
lse_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
,
...
...
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr
,
o_ptr
,
batch
,
max_seqlen_q
,
seqlen_q
,
hdim_v
,
num_splits
,
...
...
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
{},
// placeholder for fp8_static_quant args
batch_stride_o
,
batch_stride_lse_acc
};
batch_stride_lse_acc
,
batch_stride_o_acc
,
batch_stride_o
};
if
constexpr
(
kStoreLSE
)
{
...
...
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
lse_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
const
void
*
seqstart_q_ptr
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
,
...
...
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
)
{
...
...
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr
,
o_ptr
,
batch
,
max_seqlen_q
,
-
1
,
// seqlen will be updated by another pointer
hdim_v
,
num_splits
,
...
...
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
...
...
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return
kargs
;
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
_
,
ck_tile
::
index_t
nhead
_
,
ck_tile
::
index_t
seqlen_q
_
,
ck_tile
::
index_t
hdim_v
_
)
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
)
{
return
TilePartitioner
::
GridSize
(
batch_size
_
,
nhead
_
,
seqlen_q
_
,
hdim_v
_
);
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_
seqlen_q
,
hdim_v
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_o_acc
=
0
;
long_index_t
batch_offset_lse
=
0
;
long_index_t
batch_offset_o
=
0
;
...
...
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
batch_offset_lse_acc
=
query_start
;
batch_offset_o_acc
=
query_start
*
kargs
.
row_stride_o_acc
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
query_start
;
}
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
...
...
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
}
// for simplicity, batch stride we just modify the pointer
...
...
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto
o_acc_dram
=
[
&
]()
{
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
o_acc_ptr
,
make_tuple
(
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
num_splits
,
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
split_stride_o_acc
,
kargs
.
row_stride_o_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
1
>
{});
...
...
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
const
index_t
padded_
max_
seqlen_q
=
const
index_t
padded_seqlen_q
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
1
>
{}];
const
index_t
padded_hdim_v
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
2
>
{}];
return
transform_tensor_view
(
o_acc_dram_view
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_
max_
seqlen_q
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_seqlen_q
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
...
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity
{},
// lse_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
seqlen_q
,
smem_ptr
);
}
else
...
...
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window
,
lse_dram_window
,
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
seqlen_q
,
smem_ptr
);
}
}();
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
View file @
10e8be48
...
...
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
_
,
ck_tile
::
index_t
nhead
_
,
ck_tile
::
index_t
seqlen_q
_
,
ck_tile
::
index_t
hdim_v
_
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
_
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
_
,
kN1
),
nhead
_
,
batch_size
_
);
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
,
batch_size
);
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
10e8be48
...
...
@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
};
...
...
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
};
struct
GroupModeKargs
...
...
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
const
int32_t
*
seqstart_k_ptr
;
const
int32_t
*
seqlen_k_ptr
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_k
;
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_v
;
// only used for paged-kvcache
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
...
...
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for bias
...
...
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
),
batch_stride_q
,
batch_stride_k
,
batch_stride_v
};
batch_stride_v
,
batch_stride_lse_acc
,
batch_stride_o_acc
};
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
...
...
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_lse_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_lse_acc
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
batch_stride_k
,
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_v
,
// only used for paged-kvcache
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
,
ck_tile
::
index_t
window_size_left
,
...
...
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for bias
...
...
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
{
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
seqlen_q
,
hdim_v
,
num_splits
);
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_
seqlen_q
,
hdim_v
,
num_splits
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
long_index_t
batch_offset_v
=
0
;
long_index_t
batch_offset_bias
=
0
;
long_index_t
batch_offset_lse_acc
=
0
;
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
long_index_t
batch_offset_o_acc
=
0
;
if
constexpr
(
kIsGroupMode
)
{
...
...
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_lse_acc
=
query_start
;
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
...
...
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
+
key_start
;
}
batch_offset_lse_acc
=
query_start
;
batch_offset_o_acc
=
query_start
*
kargs
.
stride_o_acc
;
// get real # queries & # keys under group mode
kargs
.
seqlen_q
=
kargs
.
seqstart_q_ptr
[
i_batch
+
1
]
-
kargs
.
seqstart_q_ptr
[
i_batch
];
...
...
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
...
...
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
o_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
hdim_v
,
1
),
number
<
FmhaPipeline
::
kAlignmentO
>
{},
make_tuple
(
kargs
.
stride_o_acc
,
1
),
number
<
1
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
View file @
10e8be48
...
...
@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
,
kM0
)
*
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
*
num_splits
,
batch_size
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
View file @
10e8be48
...
...
@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
},
s_acc
,
bias_s_tile
);
__builtin_amdgcn_sched_barrier
(
0
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
...
...
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 4, OGrad@V Gemm2
auto
dp_acc
=
SPGradBlockTileType
{};
...
...
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 5, P^T(PGrad^T - D)
auto
ds
=
SPGradBlockTileType
{};
...
...
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
__builtin_amdgcn_sched_barrier
(
0
);
}
// STAGE 6, SGrad^T@Q^T Gemm3
...
...
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 7, SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
...
...
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
__builtin_amdgcn_sched_barrier
(
0
);
// Results Scale
if
constexpr
(
FmhaDropout
::
IsDropout
)
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
10e8be48
...
...
@@ -1727,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
0
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
0
>
()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
...
...
@@ -1759,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
1
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
1
>
()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
...
...
@@ -1777,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
2
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
2
>
()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
...
...
@@ -1796,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
3
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
3
>
()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
...
...
@@ -1830,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
4
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
4
>
()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
10e8be48
...
...
@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
LSEElementFunction
&
lse_element_func
,
const
OaccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
max_
seqlen_q
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
{
// lse_acc tile in LDS
...
...
@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto
o_acc
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_dist
);
clear_tile
(
o_acc
);
const
index_t
padded_
max_
seqlen_q
=
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
kM0
;
const
index_t
padded_seqlen_q
=
integer_divide_ceil
(
seqlen_q
,
kM0
)
*
kM0
;
for
(
index_t
i_split
=
0
;
i_split
<
num_splits
;
++
i_split
)
{
...
...
@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
}
move_tile_window
(
o_acc_dram_window
,
{
padded_
max_
seqlen_q
,
0
});
move_tile_window
(
o_acc_dram_window
,
{
padded_seqlen_q
,
0
});
}
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
...
...
@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
OaccDramBlockWindow
&
o_acc_dram_block_window
,
LSEDramBlockWindow
&
lse_dram_block_window
,
index_t
num_splits
,
index_t
max_
seqlen_q
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
{
return
operator
()(
lse_acc_dram_block_window
,
...
...
@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity
{},
identity
{},
num_splits
,
max_
seqlen_q
,
seqlen_q
,
smem_ptr
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
10e8be48
...
...
@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
...
...
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
// check early exit if
masked and
no work to do
.
if
constexpr
(
FmhaMask
::
IsMasking
||
kHasUnevenSplits
)
// check early exit if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
||
kHasUnevenSplits
)
{
const
index_t
original_num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
...
...
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
...
...
include/ck_tile/ops/image_to_column.hpp
0 → 100644
View file @
10e8be48
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp
0 → 100644
View file @
10e8be48
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace
ck_tile
{
template
<
typename
Problem_
>
struct
ImageToColumn
{
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
auto
I3
=
number
<
3
>
{};
static
constexpr
auto
I4
=
number
<
4
>
{};
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
InDataType
=
remove_cvref_t
<
typename
Problem
::
InDataType
>
;
using
OutDataType
=
remove_cvref_t
<
typename
Problem
::
OutDataType
>
;
static
constexpr
index_t
NDimSpatial
=
Problem
::
NDimSpatial
;
static
constexpr
index_t
AligmentIn
=
Problem
::
AligmentIn
;
static
constexpr
index_t
AligmentOut
=
Problem
::
AligmentOut
;
static_assert
(
NDimSpatial
==
2
,
"Not supported."
);
static
constexpr
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
index_t
kKPerBlock
=
Problem
::
BlockShape
::
kKPerBlock
;
struct
Kargs
{
const
void
*
p_in
;
void
*
p_out
;
const
long_index_t
G
;
const
long_index_t
N
;
const
long_index_t
C
;
const
array
<
long_index_t
,
NDimSpatial
>
input_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
>
filter_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
>
output_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
+
3
>
image_g_n_c_wis_strides
;
const
array
<
long_index_t
,
3
>
gemm_g_m_k_strides
;
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides
;
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations
;
const
array
<
long_index_t
,
NDimSpatial
>
input_left_pads
;
const
array
<
long_index_t
,
NDimSpatial
>
input_right_pads
;
};
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
void
*
p_in
,
void
*
p_out
,
const
long_index_t
G
,
const
long_index_t
N
,
const
long_index_t
C
,
const
array
<
long_index_t
,
NDimSpatial
>
input_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
>
filter_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
>
output_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
+
3
>
image_g_n_c_wis_strides
,
const
array
<
long_index_t
,
3
>
gemm_g_m_k_strides
,
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides
,
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations
,
const
array
<
long_index_t
,
NDimSpatial
>
input_left_pads
,
const
array
<
long_index_t
,
NDimSpatial
>
input_right_pads
)
{
return
Kargs
{
p_in
,
p_out
,
G
,
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
gemm_g_m_k_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
GemmM
,
index_t
GemmK
,
index_t
Batch
)
{
return
dim3
(
integer_divide_ceil
(
GemmM
,
kMPerBlock
),
integer_divide_ceil
(
GemmK
,
kKPerBlock
),
Batch
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
kBlockSize
;
}
CK_TILE_DEVICE
auto
MakeImageMKDesc
(
const
Kargs
&
kargs
)
const
{
static_assert
(
NDimSpatial
==
2
,
"Not supported."
);
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
kargs
.
N
,
kargs
.
input_spatial_lengths
[
I0
],
kargs
.
input_spatial_lengths
[
I1
],
kargs
.
C
),
make_tuple
(
kargs
.
image_g_n_c_wis_strides
[
I1
],
kargs
.
image_g_n_c_wis_strides
[
I3
],
kargs
.
image_g_n_c_wis_strides
[
I4
],
kargs
.
image_g_n_c_wis_strides
[
I2
]),
number
<
AligmentIn
>
{},
I1
);
const
auto
in_n_hip_wip_c_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
kargs
.
N
),
make_pad_transform
(
kargs
.
input_spatial_lengths
[
I0
],
kargs
.
input_left_pads
[
I0
],
kargs
.
input_right_pads
[
I0
]),
make_pad_transform
(
kargs
.
input_spatial_lengths
[
I1
],
kargs
.
input_left_pads
[
I1
],
kargs
.
input_right_pads
[
I1
]),
make_pass_through_transform
(
kargs
.
C
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_desc
,
make_tuple
(
make_pass_through_transform
(
kargs
.
N
),
make_embed_transform
(
make_tuple
(
kargs
.
filter_spatial_lengths
[
I0
],
kargs
.
output_spatial_lengths
[
I0
]),
make_tuple
(
kargs
.
conv_filter_dilations
[
I0
],
kargs
.
conv_filter_strides
[
I0
])),
make_embed_transform
(
make_tuple
(
kargs
.
filter_spatial_lengths
[
I1
],
kargs
.
output_spatial_lengths
[
I1
]),
make_tuple
(
kargs
.
conv_filter_dilations
[
I1
],
kargs
.
conv_filter_strides
[
I1
])),
make_pass_through_transform
(
kargs
.
C
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
,
2
>
{},
sequence
<
3
,
4
>
{},
sequence
<
5
>
{}));
return
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
N
,
kargs
.
output_spatial_lengths
[
I0
],
kargs
.
output_spatial_lengths
[
I1
])),
make_merge_transform
(
make_tuple
(
kargs
.
filter_spatial_lengths
[
I0
],
kargs
.
filter_spatial_lengths
[
I1
],
kargs
.
C
))),
make_tuple
(
sequence
<
0
,
2
,
4
>
{},
sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}
CK_TILE_DEVICE
auto
CalculateMKDims
(
const
Kargs
&
kargs
)
const
{
static_assert
(
NDimSpatial
==
2
,
"Not supported."
);
const
index_t
M
=
kargs
.
N
*
static_cast
<
index_t
>
(
kargs
.
output_spatial_lengths
[
I0
]
*
kargs
.
output_spatial_lengths
[
I1
]);
const
index_t
K
=
kargs
.
C
*
static_cast
<
index_t
>
(
kargs
.
filter_spatial_lengths
[
I0
]
*
kargs
.
filter_spatial_lengths
[
I1
]);
return
make_tuple
(
M
,
K
);
}
CK_TILE_DEVICE
static
constexpr
auto
MakeBlockTileDistribution
()
{
using
P
=
typename
Problem
::
BlockShape
;
// P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
// Y: {kMPerThread, kKPerThread}
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
P
::
kMWarpPerBlock
,
P
::
kMThreadPerWarp
,
P
::
kMPerThread
>
,
sequence
<
P
::
kKWarpPerBlock
,
P
::
kKThreadPerWarp
,
P
::
kKPerThread
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
2
>>
{});
}
CK_TILE_DEVICE
void
ConvTensorRearrange
(
const
Kargs
&
kargs
)
const
{
const
auto
[
M
,
K
]
=
CalculateMKDims
(
kargs
);
const
index_t
iM
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
*
kMPerBlock
);
const
index_t
iK
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
*
kKPerBlock
);
const
index_t
iBatch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
const
auto
in_offset
=
iBatch
*
kargs
.
image_g_n_c_wis_strides
[
I0
];
const
auto
out_offset
=
iBatch
*
kargs
.
gemm_g_m_k_strides
[
I0
];
const
auto
image_m_k
=
make_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
InDataType
*>
(
kargs
.
p_in
)
+
in_offset
,
MakeImageMKDesc
(
kargs
));
const
auto
gemm_m_k
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
OutDataType
*>
(
kargs
.
p_out
)
+
out_offset
,
make_tuple
(
M
,
K
),
make_tuple
(
kargs
.
gemm_g_m_k_strides
[
I1
],
kargs
.
gemm_g_m_k_strides
[
I2
]),
number
<
AligmentOut
>
{},
I1
);
const
auto
image_m_k_padded
=
pad_tensor_view
(
image_m_k
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
sequence
<
false
,
true
>
{});
const
auto
gemm_m_k_padded
=
pad_tensor_view
(
gemm_m_k
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
sequence
<
false
,
true
>
{});
constexpr
auto
dstr
=
MakeBlockTileDistribution
();
const
auto
image_tile
=
make_tile_window
(
image_m_k_padded
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
iM
,
iK
},
dstr
);
auto
gemm_tile
=
make_tile_window
(
gemm_m_k_padded
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
iM
,
iK
},
dstr
);
// load from Global
const
auto
loaded_tile
=
load_tile
(
image_tile
);
// save to Global
store_tile
(
gemm_tile
,
loaded_tile
);
}
CK_TILE_DEVICE
void
operator
()(
Kargs
&
kargs
)
const
{
ConvTensorRearrange
(
kargs
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp
0 → 100644
View file @
10e8be48
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
template
<
typename
InDataType_
,
typename
OutDataType_
,
typename
BlockShape_
,
index_t
NDimSpatial_
,
index_t
AligmentIn_
,
index_t
AligmentOut_
>
struct
BlockImageToColumnProblem
{
using
InDataType
=
remove_cvref_t
<
InDataType_
>
;
using
OutDataType
=
remove_cvref_t
<
OutDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
index_t
NDimSpatial
=
NDimSpatial_
;
static
constexpr
index_t
AligmentIn
=
AligmentIn_
;
static
constexpr
index_t
AligmentOut
=
AligmentOut_
;
};
}
// namespace ck_tile
include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp
0 → 100644
View file @
10e8be48
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
ThreadTile
,
// Sequence<...
typename
WarpTile
,
// Sequence<...
typename
BlockTile
>
// Sequence<...
struct
TileImageToColumnShape
{
static
constexpr
index_t
kMPerThread
=
ThreadTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kKPerThread
=
ThreadTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMPerWarp
=
WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kKPerWarp
=
WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMThreadPerWarp
=
kMPerWarp
/
kMPerThread
;
static
constexpr
index_t
kKThreadPerWarp
=
kKPerWarp
/
kKPerThread
;
static
constexpr
index_t
kMPerBlock
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kKPerBlock
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMWarpPerBlock
=
kMPerBlock
/
kMPerWarp
;
static
constexpr
index_t
kKWarpPerBlock
=
kKPerBlock
/
kKPerWarp
;
static
constexpr
index_t
kBlockSize
=
warpSize
*
kMWarpPerBlock
*
kKWarpPerBlock
;
};
}
// namespace ck_tile
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment