Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
63b152d6
Commit
63b152d6
authored
Oct 17, 2024
by
danyao12
Browse files
Merge branch 'develop' into ck_tile/fa_bwd_v3
parents
ae2d7d2b
14c3cfb1
Changes
132
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1145 additions
and
185 deletions
+1145
-185
include/ck_tile/core/container/array.hpp
include/ck_tile/core/container/array.hpp
+12
-1
include/ck_tile/core/container/thread_buffer.hpp
include/ck_tile/core/container/thread_buffer.hpp
+1
-1
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+2
-0
include/ck_tile/host/arg_parser.hpp
include/ck_tile/host/arg_parser.hpp
+15
-5
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp
...k_tile/host/convolution_host_tensor_descriptor_helper.hpp
+266
-0
include/ck_tile/host/convolution_parameter.hpp
include/ck_tile/host/convolution_parameter.hpp
+277
-0
include/ck_tile/host/host_tensor.hpp
include/ck_tile/host/host_tensor.hpp
+14
-1
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+37
-10
include/ck_tile/host/reference/reference_im2col.hpp
include/ck_tile/host/reference/reference_im2col.hpp
+117
-45
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-0
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+171
-0
include/ck_tile/ops/fmha/block/block_masking.hpp
include/ck_tile/ops/fmha/block/block_masking.hpp
+2
-2
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+77
-14
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+70
-13
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+22
-29
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
...fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
+8
-9
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+21
-23
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
...ile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
+2
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+12
-15
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+18
-15
No files found.
include/ck_tile/core/container/array.hpp
View file @
63b152d6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <initializer_list>
#include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integer.hpp"
...
@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
...
@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
return
!
(
a
==
b
);
return
!
(
a
==
b
);
}
}
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
std
::
vector
<
X
>&
x
)
{
array
<
T
,
N
>
arr
;
static_for
<
0
,
N
,
1
>
{}([
&
x
,
&
arr
](
auto
i
)
{
arr
(
i
)
=
x
[
i
];
});
return
arr
;
}
template
<
typename
T
,
index_t
N
,
typename
X
>
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
X
&
x
)
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
X
&
x
)
{
{
...
...
include/ck_tile/core/container/thread_buffer.hpp
View file @
63b152d6
...
@@ -58,7 +58,7 @@ struct thread_buffer {
...
@@ -58,7 +58,7 @@ struct thread_buffer {
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
()
const
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
()
const
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
at
(
number
<
I
>
)
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
at
(
number
<
I
>
)
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
(
number
<
I
>
)
const
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
(
number
<
I
>
)
const
{
return
get
(
I
);
}
template
<
typename
X_
,
template
<
typename
X_
,
typename
std
::
enable_if
<
has_same_scalar_type
<
value_type
,
X_
>
::
value
,
bool
>::
type
=
false
>
typename
std
::
enable_if
<
has_same_scalar_type
<
value_type
,
X_
>
::
value
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
auto
_get_as
()
const
CK_TILE_HOST_DEVICE
constexpr
auto
_get_as
()
const
...
...
include/ck_tile/host.hpp
View file @
63b152d6
...
@@ -5,6 +5,8 @@
...
@@ -5,6 +5,8 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/hip_check_error.hpp"
...
...
include/ck_tile/host/arg_parser.hpp
View file @
63b152d6
...
@@ -50,12 +50,22 @@ class ArgParser
...
@@ -50,12 +50,22 @@ class ArgParser
}
}
return
*
this
;
return
*
this
;
}
}
void
print
()
void
print
()
const
{
{
// find max key length
std
::
string
::
size_type
max_key_length
=
11
;
for
(
auto
&
key
:
keys
)
{
if
(
max_key_length
<
key
.
length
())
{
max_key_length
=
key
.
length
();
}
}
printf
(
"args:
\n
"
);
printf
(
"args:
\n
"
);
for
(
auto
&
key
:
keys
)
for
(
auto
&
key
:
keys
)
{
{
auto
value
=
input_map
[
key
]
;
auto
value
=
input_map
.
at
(
key
)
;
std
::
vector
<
std
::
string
>
help_text_lines
;
std
::
vector
<
std
::
string
>
help_text_lines
;
size_t
pos
=
0
;
size_t
pos
=
0
;
for
(
size_t
next_pos
=
value
.
help_text
.
find
(
'\n'
,
pos
);
next_pos
!=
std
::
string
::
npos
;)
for
(
size_t
next_pos
=
value
.
help_text
.
find
(
'\n'
,
pos
);
next_pos
!=
std
::
string
::
npos
;)
...
@@ -69,8 +79,7 @@ class ArgParser
...
@@ -69,8 +79,7 @@ class ArgParser
std
::
string
(
value
.
help_text
.
begin
()
+
pos
,
value
.
help_text
.
end
()));
std
::
string
(
value
.
help_text
.
begin
()
+
pos
,
value
.
help_text
.
end
()));
std
::
string
default_value
=
std
::
string
(
"(default:"
)
+
value
.
value
+
std
::
string
(
")"
);
std
::
string
default_value
=
std
::
string
(
"(default:"
)
+
value
.
value
+
std
::
string
(
")"
);
std
::
cout
<<
std
::
setw
(
1
+
max_key_length
-
value
.
name
.
length
())
<<
"-"
<<
key
std
::
cout
<<
std
::
setw
(
2
)
<<
std
::
setw
(
12
-
value
.
name
.
length
())
<<
"-"
<<
key
<<
std
::
setw
(
4
)
<<
" "
<<
help_text_lines
[
0
]
<<
" "
<<
default_value
<<
std
::
setw
(
4
)
<<
" "
<<
help_text_lines
[
0
]
<<
" "
<<
default_value
<<
std
::
endl
;
<<
std
::
endl
;
...
@@ -78,7 +87,8 @@ class ArgParser
...
@@ -78,7 +87,8 @@ class ArgParser
help_next_line
!=
help_text_lines
.
end
();
help_next_line
!=
help_text_lines
.
end
();
++
help_next_line
)
++
help_next_line
)
{
{
std
::
cout
<<
std
::
setw
(
17
)
<<
" "
<<
*
help_next_line
<<
std
::
endl
;
std
::
cout
<<
std
::
setw
(
1
+
max_key_length
+
4
)
<<
" "
<<
*
help_next_line
<<
std
::
endl
;
}
}
}
}
}
}
...
...
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp
0 → 100644
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace
ck_tile
{
namespace
conv
{
namespace
detail
{
template
<
typename
OldLayout
>
CK_TILE_HOST
std
::
vector
<
std
::
size_t
>
get_layout_transpose_gnchw_to_old
()
{
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKW
>
)
{
return
{
0
,
1
,
2
,
3
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCHW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCYX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKHW
>
)
{
return
{
0
,
1
,
2
,
3
,
4
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCDHW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCZYX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKDHW
>
)
{
return
{
0
,
1
,
2
,
3
,
4
,
5
};
}
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWK
>
)
{
return
{
0
,
1
,
3
,
2
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKYXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWK
>
)
{
return
{
0
,
1
,
4
,
2
,
3
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKZYXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWK
>
)
{
return
{
0
,
1
,
5
,
2
,
3
,
4
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGK
>
)
{
return
{
2
,
0
,
3
,
1
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGK
>
)
{
return
{
3
,
0
,
4
,
1
,
2
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGK
>
)
{
return
{
4
,
0
,
5
,
1
,
2
,
3
};
}
else
{
printf
(
"%s
\n
"
,
__func__
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
}
}
// namespace detail
// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
// regardless of physical layout
template
<
typename
InLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCW
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCHW
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCDHW
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
InLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
InLayout
>
());
}
// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
// regardless of physical layout
template
<
typename
WeiLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXC
>
)
{
if
(
param
.
G_
!=
1
)
{
throw
std
::
runtime_error
(
"wrong! G != 1"
);
}
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCX
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCYX
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCZYX
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKYXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKZYXC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXGC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXGC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXGC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
WeiLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
WeiLayout
>
());
}
// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
// regardless of physical layout
template
<
typename
OutLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKW
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKHW
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKDHW
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
// separate from legacy code above
else
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWK
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGK
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
OutLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
OutLayout
>
());
}
}
// namespace conv
}
// namespace ck_tile
include/ck_tile/host/convolution_parameter.hpp
0 → 100644
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>
namespace
ck_tile
{
namespace
conv
{
struct
ConvParam
{
ConvParam
(
ck_tile
::
index_t
n_dim
,
ck_tile
::
index_t
group_count
,
ck_tile
::
index_t
n_batch
,
ck_tile
::
index_t
n_out_channels
,
ck_tile
::
index_t
n_in_channels
,
const
std
::
vector
<
ck_tile
::
index_t
>&
filters_len
,
const
std
::
vector
<
ck_tile
::
index_t
>&
input_len
,
const
std
::
vector
<
ck_tile
::
index_t
>&
strides
,
const
std
::
vector
<
ck_tile
::
index_t
>&
dilations
,
const
std
::
vector
<
ck_tile
::
index_t
>&
left_pads
,
const
std
::
vector
<
ck_tile
::
index_t
>&
right_pads
)
:
num_dim_spatial_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_dim
)),
G_
(
static_cast
<
ck_tile
::
long_index_t
>
(
group_count
)),
N_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_batch
)),
K_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_out_channels
)),
C_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_in_channels
)),
filter_spatial_lengths_
(
num_dim_spatial_
),
input_spatial_lengths_
(
num_dim_spatial_
),
output_spatial_lengths_
(
num_dim_spatial_
),
conv_filter_strides_
(
num_dim_spatial_
),
conv_filter_dilations_
(
num_dim_spatial_
),
input_left_pads_
(
num_dim_spatial_
),
input_right_pads_
(
num_dim_spatial_
)
{
if
(
static_cast
<
ck_tile
::
index_t
>
(
filter_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_strides_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_dilations_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_left_pads_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_right_pads_
.
size
())
!=
num_dim_spatial_
)
{
throw
(
std
::
runtime_error
(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"
));
}
for
(
ck_tile
::
index_t
i
=
0
;
i
<
num_dim_spatial_
;
++
i
)
{
filter_spatial_lengths_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
filters_len
[
i
]);
input_spatial_lengths_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
input_len
[
i
]);
conv_filter_strides_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
strides
[
i
]);
conv_filter_dilations_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
dilations
[
i
]);
input_left_pads_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
left_pads
[
i
]);
input_right_pads_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
right_pads
[
i
]);
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck_tile
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
output_spatial_lengths_
[
i
]
=
(
input_spatial_lengths_
[
i
]
+
input_left_pads_
[
i
]
+
input_right_pads_
[
i
]
-
x_eff
)
/
conv_filter_strides_
[
i
]
+
1
;
}
}
ConvParam
(
ck_tile
::
long_index_t
n_dim
,
ck_tile
::
long_index_t
group_count
,
ck_tile
::
long_index_t
n_batch
,
ck_tile
::
long_index_t
n_out_channels
,
ck_tile
::
long_index_t
n_in_channels
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
filters_len
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
input_len
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
strides
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
dilations
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
left_pads
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
right_pads
)
:
num_dim_spatial_
(
n_dim
),
G_
(
group_count
),
N_
(
n_batch
),
K_
(
n_out_channels
),
C_
(
n_in_channels
),
filter_spatial_lengths_
(
filters_len
),
input_spatial_lengths_
(
input_len
),
output_spatial_lengths_
(
num_dim_spatial_
),
conv_filter_strides_
(
strides
),
conv_filter_dilations_
(
dilations
),
input_left_pads_
(
left_pads
),
input_right_pads_
(
right_pads
)
{
if
(
static_cast
<
ck_tile
::
index_t
>
(
filter_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_strides_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_dilations_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_left_pads_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_right_pads_
.
size
())
!=
num_dim_spatial_
)
{
throw
(
std
::
runtime_error
(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"
));
}
for
(
ck_tile
::
index_t
i
=
0
;
i
<
num_dim_spatial_
;
++
i
)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck_tile
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
output_spatial_lengths_
[
i
]
=
(
input_spatial_lengths_
[
i
]
+
input_left_pads_
[
i
]
+
input_right_pads_
[
i
]
-
x_eff
)
/
conv_filter_strides_
[
i
]
+
1
;
}
}
ck_tile
::
long_index_t
num_dim_spatial_
;
ck_tile
::
long_index_t
G_
;
ck_tile
::
long_index_t
N_
;
ck_tile
::
long_index_t
K_
;
ck_tile
::
long_index_t
C_
;
std
::
vector
<
ck_tile
::
long_index_t
>
filter_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
output_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_strides_
;
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_dilations_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_left_pads_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_right_pads_
;
std
::
vector
<
ck_tile
::
long_index_t
>
GetOutputSpatialLengths
()
const
{
return
output_spatial_lengths_
;
}
std
::
size_t
GetFlops
()
const
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return
static_cast
<
std
::
size_t
>
(
2
)
*
G_
*
N_
*
K_
*
C_
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
next
(
std
::
begin
(
output_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
())
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
next
(
std
::
begin
(
filter_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
());
}
template
<
typename
InDataType
>
std
::
size_t
GetInputByte
()
const
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return
sizeof
(
InDataType
)
*
(
G_
*
N_
*
C_
*
std
::
accumulate
(
std
::
begin
(
input_spatial_lengths_
),
std
::
next
(
std
::
begin
(
input_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
()));
}
template
<
typename
WeiDataType
>
std
::
size_t
GetWeightByte
()
const
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return
sizeof
(
WeiDataType
)
*
(
G_
*
K_
*
C_
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
next
(
std
::
begin
(
filter_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
()));
}
template
<
typename
OutDataType
>
std
::
size_t
GetOutputByte
()
const
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return
sizeof
(
OutDataType
)
*
(
G_
*
N_
*
K_
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
end
(
output_spatial_lengths_
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
()));
}
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
>
std
::
size_t
GetByte
()
const
{
return
GetInputByte
<
InDataType
>
()
+
GetWeightByte
<
WeiDataType
>
()
+
GetOutputByte
<
OutDataType
>
();
}
};
CK_TILE_HOST
std
::
string
get_conv_param_parser_helper_msg
()
{
std
::
string
msg
;
msg
+=
"Following arguments (depending on number of spatial dims):
\n
"
" Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)
\n
"
" G, N, K, C,
\n
"
" <filter spatial dimensions>, (ie Y, X for 2D)
\n
"
" <input image spatial dimensions>, (ie Hi, Wi for 2D)
\n
"
" <strides>, (ie Sy, Sx for 2D)
\n
"
" <dilations>, (ie Dy, Dx for 2D)
\n
"
" <left padding>, (ie LeftPy, LeftPx for 2D)
\n
"
" <right padding>, (ie RightPy, RightPx for 2D)
\n
"
;
return
msg
;
}
CK_TILE_HOST
ck_tile
::
conv
::
ConvParam
parse_conv_param
(
int
num_dim_spatial
,
int
arg_idx
,
char
*
const
argv
[])
{
const
ck_tile
::
long_index_t
G
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
N
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
K
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
C
=
std
::
stol
(
argv
[
arg_idx
++
]);
std
::
vector
<
ck_tile
::
long_index_t
>
filter_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_strides
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_dilations
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_left_pads
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_right_pads
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
filter_spatial_lengths
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_spatial_lengths
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
conv_filter_strides
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
conv_filter_dilations
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_left_pads
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_right_pads
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
return
ck_tile
::
conv
::
ConvParam
{
num_dim_spatial
,
G
,
N
,
K
,
C
,
filter_spatial_lengths
,
input_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
}
}
// namespace conv
}
// namespace ck_tile
include/ck_tile/host/host_tensor.hpp
View file @
63b152d6
...
@@ -176,7 +176,20 @@ struct HostTensorDescriptor
...
@@ -176,7 +176,20 @@ struct HostTensorDescriptor
return
std
::
inner_product
(
iss
.
begin
(),
iss
.
end
(),
mStrides
.
begin
(),
std
::
size_t
{
0
});
return
std
::
inner_product
(
iss
.
begin
(),
iss
.
end
(),
mStrides
.
begin
(),
std
::
size_t
{
0
});
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
HostTensorDescriptor
&
desc
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
HostTensorDescriptor
&
desc
)
{
os
<<
"dim "
<<
desc
.
get_num_of_dimension
()
<<
", "
;
os
<<
"lengths {"
;
LogRange
(
os
,
desc
.
get_lengths
(),
", "
);
os
<<
"}, "
;
os
<<
"strides {"
;
LogRange
(
os
,
desc
.
get_strides
(),
", "
);
os
<<
"}"
;
return
os
;
}
private:
private:
std
::
vector
<
std
::
size_t
>
mLens
;
std
::
vector
<
std
::
size_t
>
mLens
;
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
63b152d6
...
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const
BElementOp
&
b_element_op
=
{},
const
BElementOp
&
b_element_op
=
{},
const
ACCElementOp
&
acc_element_op
=
{})
const
ACCElementOp
&
acc_element_op
=
{})
{
{
const
int
N
=
b_n_k
.
mDesc
.
get_lengths
()[
0
];
const
int
N
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
b_n_k
.
mDesc
.
get_lengths
()[
0
]
:
b_n_k
.
mDesc
.
get_lengths
()[
1
];
const
int
K
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
const
int
K
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_m_k
.
mDesc
.
get_lengths
()[
1
]
?
a_m_k
.
mDesc
.
get_lengths
()[
1
]
:
a_m_k
.
mDesc
.
get_lengths
()[
0
];
:
a_m_k
.
mDesc
.
get_lengths
()[
0
];
...
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType
v_a
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
ADataType
v_a
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_element_op
(
a_m_k
(
m
,
k
))
?
a_element_op
(
a_m_k
(
m
,
k
))
:
a_element_op
(
a_m_k
(
k
,
m
));
:
a_element_op
(
a_m_k
(
k
,
m
));
BDataType
v_b
=
b_element_op
(
b_n_k
(
n
,
k
));
BDataType
v_b
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
b_element_op
(
b_n_k
(
n
,
k
))
:
b_element_op
(
b_n_k
(
k
,
n
));
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck_tile
::
type_convert
<
AccDataType
>
(
v_b
);
ck_tile
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
c_m_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
CDataType
>
(
acc_element_op
(
v_acc
));
CDataType
&
c_ref
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
c_m_n
(
m
,
n
)
:
c_m_n
(
n
,
m
);
c_ref
=
ck_tile
::
type_convert
<
CDataType
>
(
acc_element_op
(
v_acc
));
}
}
};
};
make_ParallelTensorFunctor
(
f
,
M
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f
,
M
)(
std
::
thread
::
hardware_concurrency
());
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
__global__
void
naive_gemm_kernel
(
ADataType
*
A
,
__global__
void
naive_gemm_kernel
(
ADataType
*
A
,
BDataType
*
B
,
BDataType
*
B
,
CDataType
*
C
,
CDataType
*
C
,
...
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
...
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if
(
row
<
M
&&
col
<
N
)
if
(
row
<
M
&&
col
<
N
)
{
{
AccDataType
acc
=
0.0
;
AccDataType
acc
=
0.0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
acc
+=
static_cast
<
AccDataType
>
(
A
[
row
*
strideA
+
k
])
*
// Adjust indexing based on matrix layout
static_cast
<
AccDataType
>
(
B
[
col
*
strideB
+
k
]);
int
a_index
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideA
+
k
:
k
*
strideA
+
row
;
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
col
*
strideB
+
k
:
k
*
strideB
+
col
;
acc
+=
static_cast
<
AccDataType
>
(
A
[
a_index
])
*
static_cast
<
AccDataType
>
(
B
[
b_index
]);
}
}
C
[
row
*
strideC
+
col
]
=
acc
;
// Store as AccDataType
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideC
+
col
:
col
*
strideC
+
row
;
C
[
c_index
]
=
acc
;
}
}
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
DeviceMem
&
c_device
,
...
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
...
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
errC
=
hipMemcpy
(
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
...
...
include/ck_tile/host/reference/reference_im2col.hpp
View file @
63b152d6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -9,53 +9,125 @@
...
@@ -9,53 +9,125 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
T
>
template
<
typename
InDataType
,
typename
OutDataType
,
index_t
NDimSpatial
>
CK_TILE_HOST
void
reference_im2col
(
HostTensor
<
T
>&
in_mtx_host_ref
,
CK_TILE_HOST
void
reference_im2col
(
const
HostTensor
<
InDataType
>&
in_host
,
const
HostTensor
<
T
>&
in_host
,
HostTensor
<
OutDataType
>&
out_host
,
int
/*N*/
,
const
ck_tile
::
conv
::
ConvParam
&
conv_params
)
int
/*K*/
,
int
C
,
int
/*Y*/
,
int
X
,
int
Hi
,
int
Wi
,
int
Ho
,
int
Wo
,
int
ConvStrideH
,
int
ConvStrideW
,
int
ConvDilationH
,
int
ConvDilationW
,
int
InLeftPadH
,
int
InLeftPadW
,
int
/*InRightPadH*/
,
int
/*InRightPadW*/
)
{
{
int
GemmM
=
in_mtx_host_ref
.
get_lengths
()[
0
];
const
long_index_t
G
=
in_host
.
get_lengths
()[
0
];
int
GemmK
=
in_mtx_host_ref
.
get_lengths
()[
1
];
const
long_index_t
N
=
in_host
.
get_lengths
()[
1
];
const
long_index_t
C
=
in_host
.
get_lengths
()[
2
];
for
(
int
gemm_m
=
0
;
gemm_m
<
GemmM
;
++
gemm_m
)
if
constexpr
(
NDimSpatial
==
1
)
{
{
int
mtmp
=
gemm_m
;
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
0
];
int
n
=
mtmp
/
(
Ho
*
Wo
);
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
wo
)
{
mtmp
-=
n
*
Ho
*
Wo
;
long_index_t
row
=
n
*
Wo
+
wo
;
int
ho
=
mtmp
/
Wo
;
long_index_t
column
=
0
;
int
wo
=
mtmp
-
ho
*
Wo
;
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
x
)
for
(
int
gemm_k
=
0
;
gemm_k
<
GemmK
;
++
gemm_k
)
{
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
0
])
+
int
ktmp
=
gemm_k
;
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
0
])
-
int
y
=
ktmp
/
(
X
*
C
);
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
ktmp
-=
y
*
X
*
C
;
int
x
=
ktmp
/
C
;
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
int
c
=
ktmp
-
x
*
C
;
{
if
(
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
3
])
int
hi
=
y
*
ConvDilationH
+
ho
*
ConvStrideH
-
InLeftPadH
;
{
int
wi
=
x
*
ConvDilationW
+
wo
*
ConvStrideW
-
InLeftPadW
;
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
bool
inbound
=
(
hi
>=
0
&&
hi
<
Hi
&&
wi
>=
0
&&
wi
<
Wi
);
}
column
++
;
in_mtx_host_ref
(
gemm_m
,
gemm_k
)
=
inbound
?
in_host
(
n
,
hi
,
wi
,
c
)
:
0
;
}
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
long_index_t
Ho
=
conv_params
.
output_spatial_lengths_
[
0
];
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
1
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
ho
,
auto
wo
)
{
long_index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_index_t
column
=
0
;
for
(
long_index_t
y
=
0
;
y
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
y
)
{
auto
hi
=
static_cast
<
long_index_t
>
(
ho
*
conv_params
.
conv_filter_strides_
[
0
])
+
static_cast
<
long_index_t
>
(
y
*
conv_params
.
conv_filter_dilations_
[
0
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
1
];
++
x
)
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
1
])
+
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
1
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
1
]);
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
hi
>=
0
&&
type_convert
<
std
::
size_t
>
(
hi
)
<
in_host
.
get_lengths
()[
3
]
&&
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
4
])
{
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
hi
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
}
column
++
;
}
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Ho
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
long_index_t
Do
=
conv_params
.
output_spatial_lengths_
[
0
];
const
long_index_t
Ho
=
conv_params
.
output_spatial_lengths_
[
1
];
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
2
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
long_index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_index_t
column
=
0
;
for
(
long_index_t
z
=
0
;
z
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
z
)
{
auto
di
=
static_cast
<
long_index_t
>
(
d_o
*
conv_params
.
conv_filter_strides_
[
0
])
+
static_cast
<
long_index_t
>
(
z
*
conv_params
.
conv_filter_dilations_
[
0
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
for
(
long_index_t
y
=
0
;
y
<
conv_params
.
filter_spatial_lengths_
[
1
];
++
y
)
{
auto
hi
=
static_cast
<
long_index_t
>
(
ho
*
conv_params
.
conv_filter_strides_
[
1
])
+
static_cast
<
long_index_t
>
(
y
*
conv_params
.
conv_filter_dilations_
[
1
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
1
]);
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
2
];
++
x
)
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
2
])
+
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
2
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
2
]);
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
di
>=
0
&&
type_convert
<
std
::
size_t
>
(
di
)
<
in_host
.
get_lengths
()[
3
]
&&
hi
>=
0
&&
type_convert
<
std
::
size_t
>
(
hi
)
<
in_host
.
get_lengths
()[
4
]
&&
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
5
])
{
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
di
,
hi
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
}
column
++
;
}
}
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Do
,
Ho
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
}
}
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/epilogue.hpp
View file @
63b152d6
...
@@ -3,5 +3,6 @@
...
@@ -3,5 +3,6 @@
#pragma once
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
0 → 100644
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace
ck_tile
{
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template
<
typename
AccDataType_
,
typename
ODataType_
,
bool
kPadM_
,
bool
kPadN_
,
bool
kTilePermute_
,
index_t
kRank_
,
index_t
kPerm0
,
index_t
kPerm1
,
index_t
TileSize0
,
index_t
TileSize1
,
index_t
kPerm2
=
0
,
index_t
kPerm3
=
0
,
index_t
kPerm4
=
0
,
index_t
TileSize2
=
0
,
index_t
TileSize3
=
0
,
index_t
TileSize4
=
0
>
struct
CShuffleEpilogueProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTilePermute
=
kTilePermute_
;
static
constexpr
index_t
kRank
=
kRank_
;
static
constexpr
index_t
kPerm
[
CK_TILE_MAX_RANK
]
=
{
kPerm0
,
kPerm1
,
kPerm2
,
kPerm3
,
kPerm4
};
static
constexpr
index_t
tile_sizes
[
CK_TILE_MAX_RANK
]
=
{
TileSize0
,
TileSize1
,
TileSize2
,
TileSize3
,
TileSize4
};
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
CShuffleEpilogue
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
const
index_t
*
kPerm
=
Problem
::
kPerm
;
static
constexpr
bool
kTilePermute
=
Problem
::
kTilePermute
;
static
constexpr
index_t
kRank
=
Problem
::
kRank
;
const
index_t
*
tile_sizes
=
Problem
::
tile_sizes
;
// No additional shared memory needed
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
template
<
typename
OAccTile
>
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
{
using
DataType
=
typename
OAccTile
::
DataType
;
// Get thread buffer
auto
&
thread_buf
=
o_acc_tile
.
get_thread_buffer
();
// Create a temporary buffer to hold the permuted data
thread_buffer
<
DataType
,
OAccTile
::
kThreadElementSpaceSize
>
permuted_thread_buf
;
// Get the lengths of each dimension
auto
thread_tensor_lengths
=
o_acc_tile
.
get_lengths
();
// Total number of elements
index_t
total_elements
=
OAccTile
::
kThreadElementSpaceSize
;
// Iterate over all elements
for
(
index_t
linear_idx
=
0
;
linear_idx
<
total_elements
;
++
linear_idx
)
{
// Convert linear index to multi-dimensional indices
array
<
index_t
,
kRank
>
indices
;
index_t
remaining
=
linear_idx
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
indices
(
rev_i
)
=
remaining
%
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
remaining
/=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Apply the permutation
array
<
index_t
,
kRank
>
permuted_indices
;
static_for
<
0
,
kRank
,
1
>
{}(
[
&
](
auto
i
)
{
permuted_indices
(
i
)
=
indices
.
get
(
number
<
Problem
::
kPerm
[
i
]
>
{});
});
// Compute offsets
index_t
dst_offset
=
0
;
index_t
stride
=
1
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
dst_offset
+=
permuted_indices
[
rev_i
]
*
stride
;
stride
*=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Move the data
permuted_thread_buf
(
dst_offset
)
=
thread_buf
[
linear_idx
];
}
// Copy the permuted data back to the original thread buffer
for
(
index_t
i
=
0
;
i
<
total_elements
;
++
i
)
{
thread_buf
.
set_as
(
i
,
permuted_thread_buf
.
get
(
i
));
}
}
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
{
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t
tile_coords
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
tile_coords
[
i
]
=
current_window_origin
[
i
]
/
tile_sizes
[
i
];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t
permuted_tile_coords
[
CK_TILE_MAX_RANK
];
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_tile_coords
[
i
]
=
tile_coords
[
kPerm
[
i
]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t
permuted_window_origin
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_window_origin
[
i
]
=
permuted_tile_coords
[
i
]
*
tile_sizes
[
i
];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename
ODramWindowTmp
::
BottomTensorIndex
step
=
{};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
step
[
i
]
=
permuted_window_origin
[
i
]
-
current_window_origin
[
i
];
}
// Move the window
move_tile_window
(
o_dram_window_tmp
,
step
);
// Permute the data within the tile if necessary
if
constexpr
(
kTilePermute
)
{
permute_tile_data
(
o_acc_tile
);
}
// Store the tile data to the permuted location
if
constexpr
(
kPadM
||
kPadN
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/block/block_masking.hpp
View file @
63b152d6
...
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
...
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{
{
auto
[
origin_start
,
origin_end
]
=
GetTileRangeAlongX
(
i_y
,
height
,
width
);
auto
[
origin_start
,
origin_end
]
=
GetTileRangeAlongX
(
i_y
,
height
,
width
);
const
index_t
x_per_split
=
ck_tile
::
max
(
1
,
x_total
/
num_splits
);
const
index_t
x_per_split
=
ck_tile
::
max
(
1
,
integer_divide_ceil
(
x_total
,
num_splits
)
)
;
const
index_t
split_start
=
x_per_split
*
i_split
;
const
index_t
split_start
=
x_per_split
*
i_split
;
const
index_t
split_end
=
(
i_split
==
num_splits
-
1
?
x_total
:
split_start
+
x_per_split
)
;
const
index_t
split_end
=
split_start
+
x_per_split
;
return
ck_tile
::
make_tuple
(
ck_tile
::
max
(
origin_start
,
split_start
),
return
ck_tile
::
make_tuple
(
ck_tile
::
max
(
origin_start
,
split_start
),
ck_tile
::
min
(
origin_end
,
split_end
));
ck_tile
::
min
(
origin_end
,
split_end
));
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
63b152d6
...
@@ -6,8 +6,11 @@
...
@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <string>
#include <type_traits>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...
@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
...
@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
GenericAttentionMaskEnum
mask_type
;
ck_tile
::
GenericAttentionMaskEnum
mask_type
;
};
};
struct
FmhaBwd
Common
Dropout
Kargs
struct
FmhaBwdDropout
SeedOffset
{
{
void
init_dropout
(
const
float
p_drop
,
template
<
typename
T
>
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
,
union
ValueOrPointer
const
float
raw_scale
)
{
T
val
;
const
T
*
ptr
;
};
ValueOrPointer
<
uint64_t
>
drop_seed
;
ValueOrPointer
<
uint64_t
>
drop_offset
;
bool
is_drop_seed_offset_from_host
;
};
struct
FmhaBwdCommonDropoutKargs
:
FmhaBwdDropoutSeedOffset
{
void
init_dropout
(
float
p_drop
,
uint64_t
seed
,
uint64_t
offset
,
float
raw_scale
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
scale_rp_undrop
=
rp_undrop
*
raw_scale
;
this
->
drop_seed
.
val
=
seed
;
this
->
drop_offset
.
val
=
offset
;
this
->
is_drop_seed_offset_from_host
=
true
;
}
void
init_dropout
(
float
p_drop
,
const
uint64_t
*
seed_ptr
,
const
uint64_t
*
offset_ptr
,
float
raw_scale
)
{
{
float
p_undrop
=
1.0
-
p_drop
;
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
p_undrop_in_uint8_t
=
...
@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
...
@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
rp_undrop
=
1.0
/
p_undrop
;
rp_undrop
=
1.0
/
p_undrop
;
scale_rp_undrop
=
rp_undrop
*
raw_scale
;
scale_rp_undrop
=
rp_undrop
*
raw_scale
;
drop_seed
=
std
::
get
<
0
>
(
drop_seed_offset
);
this
->
drop_seed
.
ptr
=
seed_ptr
;
drop_offset
=
std
::
get
<
1
>
(
drop_seed_offset
);
this
->
drop_offset
.
ptr
=
offset_ptr
;
this
->
is_drop_seed_offset_from_host
=
false
;
}
}
float
rp_undrop
=
1
;
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
void
*
rand_val_ptr
=
nullptr
;
ck_tile
::
index_t
stride_randval
=
0
;
ck_tile
::
index_t
stride_randval
=
0
;
ck_tile
::
index_t
nhead_stride_randval
=
0
;
ck_tile
::
index_t
nhead_stride_randval
=
0
;
};
};
struct
FmhaBwdBatchModeDropoutKargs
:
FmhaBwdCommonDropoutKargs
struct
FmhaBwdBatchModeDropoutKargs
:
FmhaBwdCommonDropoutKargs
{
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
ck_tile
::
index_t
batch_stride_randval
=
0
;
};
};
struct
FmhaBwdDeterministicKargs
struct
FmhaBwdDeterministicKargs
{
{
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
...
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
...
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
{
Kargs
kargs
{{
q_ptr
,
Kargs
kargs
{{
q_ptr
,
k_ptr
,
k_ptr
,
...
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
...
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasDropout
)
{
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
,
scale
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
),
scale
);
}
if
constexpr
(
kIsStoreRandval
)
if
constexpr
(
kIsStoreRandval
)
{
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
rand_val_ptr
=
rand_val_ptr
;
...
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
...
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
{
Kargs
kargs
{{
q_ptr
,
Kargs
kargs
{{
q_ptr
,
k_ptr
,
k_ptr
,
...
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
...
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
}
}
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasDropout
)
{
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
,
scale
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
),
scale
);
}
if
constexpr
(
kIsStoreRandval
)
if
constexpr
(
kIsStoreRandval
)
{
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
rand_val_ptr
=
rand_val_ptr
;
...
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
...
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
return
FmhaDropout
{
i_batch_
,
return
FmhaDropout
{
i_batch_
,
i_nhead_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_seed
.
val
kargs
.
drop_offset
,
:
*
kargs
.
drop_seed
.
ptr
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_offset
.
val
:
*
kargs
.
drop_offset
.
ptr
,
kargs
.
rp_undrop
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
};
kargs
.
p_undrop_in_uint8_t
};
}
}
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
63b152d6
...
@@ -6,8 +6,11 @@
...
@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <string>
#include <type_traits>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
...
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
ck_tile
::
index_t
batch_stride_lse
=
0
;
ck_tile
::
index_t
batch_stride_lse
=
0
;
};
};
struct
FmhaFwd
Common
Dropout
Kargs
struct
FmhaFwdDropout
SeedOffset
{
{
void
init_dropout
(
const
float
p_drop
,
template
<
typename
T
>
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
union
ValueOrPointer
{
T
val
;
const
T
*
ptr
;
};
ValueOrPointer
<
uint64_t
>
drop_seed
;
ValueOrPointer
<
uint64_t
>
drop_offset
;
bool
is_drop_seed_offset_from_host
;
};
struct
FmhaFwdCommonDropoutKargs
:
FmhaFwdDropoutSeedOffset
{
void
init_dropout
(
float
p_drop
,
uint64_t
seed
,
uint64_t
offset
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
this
->
drop_seed
.
val
=
seed
;
this
->
drop_offset
.
val
=
offset
;
this
->
is_drop_seed_offset_from_host
=
true
;
}
void
init_dropout
(
float
p_drop
,
const
uint64_t
*
seed_ptr
,
const
uint64_t
*
offset_ptr
)
{
{
float
p_undrop
=
1.0
-
p_drop
;
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
rp_undrop
=
1.0
/
p_undrop
;
drop_seed
=
std
::
get
<
0
>
(
drop_seed_offset
);
this
->
drop_seed
.
ptr
=
seed_ptr
;
drop_offset
=
std
::
get
<
1
>
(
drop_seed_offset
);
this
->
drop_offset
.
ptr
=
offset_ptr
;
this
->
is_drop_seed_offset_from_host
=
false
;
}
}
float
rp_undrop
=
1
;
float
rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bool
is_store_randval
=
false
;
bool
is_store_randval
=
false
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
void
*
rand_val_ptr
=
nullptr
;
ck_tile
::
index_t
stride_randval
=
0
;
ck_tile
::
index_t
stride_randval
=
0
;
ck_tile
::
index_t
nhead_stride_randval
=
0
;
ck_tile
::
index_t
nhead_stride_randval
=
0
;
};
};
struct
FmhaFwdBatchModeDropoutKargs
:
FmhaFwdCommonDropoutKargs
struct
FmhaFwdBatchModeDropoutKargs
:
FmhaFwdCommonDropoutKargs
{
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
ck_tile
::
index_t
batch_stride_randval
=
0
;
...
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
...
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
ck_tile
::
index_t
mask_type
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
float
p_drop
,
bool
s_randval
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
{
Kargs
kargs
{{
q_ptr
,
Kargs
kargs
{{
q_ptr
,
k_ptr
,
k_ptr
,
...
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
...
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
}
}
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasDropout
)
{
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
));
}
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
...
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
...
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
ck_tile
::
index_t
mask_type
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
float
p_drop
,
bool
s_randval
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
{
Kargs
kargs
{{
q_ptr
,
Kargs
kargs
{{
q_ptr
,
k_ptr
,
k_ptr
,
...
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
...
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
}
}
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasDropout
)
{
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
));
}
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
...
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
...
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
return
BlockDropout
{
i_batch_
,
return
BlockDropout
{
i_batch_
,
i_nhead_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_seed
.
val
kargs
.
drop_offset
,
:
*
kargs
.
drop_seed
.
ptr
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_offset
.
val
:
*
kargs
.
drop_offset
.
ptr
,
kargs
.
rp_undrop
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
,
kargs
.
p_undrop_in_uint8_t
,
kargs
.
is_store_randval
};
kargs
.
is_store_randval
};
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
63b152d6
...
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
o_ptr
;
void
*
o_ptr
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
max_seqlen_q
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
num_splits
;
ck_tile
::
index_t
num_splits
;
...
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
};
};
...
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std
::
conditional_t
<
kStoreLSE
,
CommonLSEKargs
,
EmptyKargs
<
0
>>
,
std
::
conditional_t
<
kStoreLSE
,
CommonLSEKargs
,
EmptyKargs
<
0
>>
,
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
1
>>
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
1
>>
{
{
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
batch_stride_o
;
};
};
struct
GroupModeKargs
struct
GroupModeKargs
...
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
lse_ptr
,
void
*
lse_ptr
,
void
*
o_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
,
ck_tile
::
index_t
num_splits
,
...
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr
,
o_acc_ptr
,
o_ptr
,
o_ptr
,
batch
,
batch
,
max_seqlen_q
,
seqlen_q
,
seqlen_q
,
hdim_v
,
hdim_v
,
num_splits
,
num_splits
,
...
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
nhead_stride_o
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
{},
// placeholder for lse
{},
// placeholder for fp8_static_quant args
{},
// placeholder for fp8_static_quant args
batch_stride_o
,
batch_stride_lse_acc
,
batch_stride_lse_acc
};
batch_stride_o_acc
,
batch_stride_o
};
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
{
{
...
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
lse_ptr
,
void
*
lse_ptr
,
void
*
o_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_q_ptr
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
,
ck_tile
::
index_t
num_splits
,
...
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
)
ck_tile
::
index_t
split_stride_o_acc
)
{
{
...
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr
,
o_acc_ptr
,
o_ptr
,
o_ptr
,
batch
,
batch
,
max_seqlen_q
,
-
1
,
// seqlen will be updated by another pointer
-
1
,
// seqlen will be updated by another pointer
hdim_v
,
hdim_v
,
num_splits
,
num_splits
,
...
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
nhead_stride_o
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
{},
// placeholder for lse
...
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return
kargs
;
return
kargs
;
}
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
_
,
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
_
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
_
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
_
)
ck_tile
::
index_t
hdim_v
)
{
{
return
TilePartitioner
::
GridSize
(
batch_size
_
,
nhead
_
,
seqlen_q
_
,
hdim_v
_
);
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_
seqlen_q
,
hdim_v
);
}
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_o_acc
=
0
;
long_index_t
batch_offset_lse
=
0
;
long_index_t
batch_offset_lse
=
0
;
long_index_t
batch_offset_o
=
0
;
long_index_t
batch_offset_o
=
0
;
...
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
// get starting offset for each batch
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
batch_offset_lse_acc
=
query_start
;
batch_offset_lse_acc
=
query_start
;
batch_offset_o_acc
=
query_start
*
kargs
.
row_stride_o_acc
;
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
{
{
batch_offset_lse
=
query_start
;
batch_offset_lse
=
query_start
;
}
}
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
// get real # queries & # keys under group mode
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
...
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
}
}
else
else
{
{
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
kStoreLSE
)
if
constexpr
(
kStoreLSE
)
{
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
}
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
}
}
// for simplicity, batch stride we just modify the pointer
// for simplicity, batch stride we just modify the pointer
...
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto
o_acc_dram
=
[
&
]()
{
auto
o_acc_dram
=
[
&
]()
{
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
o_acc_ptr
,
o_acc_ptr
,
make_tuple
(
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
num_splits
,
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
split_stride_o_acc
,
kargs
.
row_stride_o_acc
,
1
),
make_tuple
(
kargs
.
split_stride_o_acc
,
kargs
.
row_stride_o_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
const
index_t
padded_
max_
seqlen_q
=
const
index_t
padded_seqlen_q
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
1
>
{}];
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
1
>
{}];
const
index_t
padded_hdim_v
=
const
index_t
padded_hdim_v
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
2
>
{}];
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
2
>
{}];
return
transform_tensor_view
(
return
transform_tensor_view
(
o_acc_dram_view
,
o_acc_dram_view
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_
max_
seqlen_q
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_seqlen_q
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity
{},
// lse_element_func
identity
{},
// lse_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
kargs
.
num_splits
,
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
seqlen_q
,
smem_ptr
);
smem_ptr
);
}
}
else
else
...
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window
,
o_acc_dram_window
,
lse_dram_window
,
lse_dram_window
,
kargs
.
num_splits
,
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
seqlen_q
,
smem_ptr
);
smem_ptr
);
}
}
}();
}();
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
View file @
63b152d6
...
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
...
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
_
,
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
_
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
_
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
_
)
ck_tile
::
index_t
hdim_v
)
{
{
// TODO: this may need tuning
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
_
,
kM0
)
*
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
_
,
kN1
),
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
_
,
nhead
,
batch_size
_
);
batch_size
);
}
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_block
=
blockIdx
.
x
;
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
63b152d6
...
@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
...
@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
};
};
...
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
...
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
};
};
struct
GroupModeKargs
struct
GroupModeKargs
...
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
...
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
const
int32_t
*
seqstart_k_ptr
;
const
int32_t
*
seqstart_k_ptr
;
const
int32_t
*
seqlen_k_ptr
;
const
int32_t
*
seqlen_k_ptr
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_k
;
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_v
;
// only used for paged-kvcache
};
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
...
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
...
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
split_stride_o_acc
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for bias
...
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
...
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
),
batch_stride_q
,
batch_stride_q
,
batch_stride_k
,
batch_stride_k
,
batch_stride_v
};
batch_stride_v
,
batch_stride_lse_acc
,
batch_stride_o_acc
};
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
...
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
...
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_lse_acc
,
ck_tile
::
index_t
nhead_stride_lse_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_k
,
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_v
,
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_lse_acc
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
,
ck_tile
::
index_t
split_stride_o_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_left
,
...
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
...
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
split_stride_o_acc
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for bias
...
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
...
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
ck_tile
::
index_t
num_splits
)
{
{
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
seqlen_q
,
hdim_v
,
num_splits
);
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_
seqlen_q
,
hdim_v
,
num_splits
);
}
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
long_index_t
batch_offset_v
=
0
;
long_index_t
batch_offset_v
=
0
;
long_index_t
batch_offset_bias
=
0
;
long_index_t
batch_offset_bias
=
0
;
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_lse_acc
=
0
;
const
long_index_t
batch_offset_o_acc
=
long_index_t
batch_offset_o_acc
=
0
;
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
kIsGroupMode
)
if
constexpr
(
kIsGroupMode
)
{
{
...
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
...
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_lse_acc
=
query_start
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
...
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
...
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
+
key_start
;
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
+
key_start
;
}
}
batch_offset_lse_acc
=
query_start
;
batch_offset_o_acc
=
query_start
*
kargs
.
stride_o_acc
;
// get real # queries & # keys under group mode
// get real # queries & # keys under group mode
kargs
.
seqlen_q
=
kargs
.
seqstart_q_ptr
[
i_batch
+
1
]
-
kargs
.
seqstart_q_ptr
[
i_batch
];
kargs
.
seqlen_q
=
kargs
.
seqstart_q_ptr
[
i_batch
+
1
]
-
kargs
.
seqstart_q_ptr
[
i_batch
];
...
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
...
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
...
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
o_acc_ptr
,
o_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
hdim_v
,
1
),
make_tuple
(
kargs
.
stride_o_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentO
>
{},
number
<
1
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
return
pad_tensor_view
(
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
View file @
63b152d6
...
@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
...
@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
ck_tile
::
index_t
num_splits
)
{
{
// TODO: this may need tuning
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
,
kM0
)
*
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
*
num_splits
,
nhead
*
num_splits
,
batch_size
);
batch_size
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @
63b152d6
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKReg
Slice
BlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
...
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
// KT, Reg ->LDS ->Reg
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
block_sync_lds
();
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
auto
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
block_sync_lds
();
block_sync_lds
();
//---------------------------- Loop Load in ----------------------------//
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
// Q: HBM ->Reg ->LDS
...
@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
...
@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
...
@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_do_lds_write_window
=
make_tile_window
(
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
index_t
i_total_loops
=
0
;
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
=
=
kK0
,
"kQKHeaddim should equal
t
o kK0"
);
static_assert
(
kQKHeaddim
>
=
kK0
,
"kQKHeaddim should
be
equal o
r greater than
kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
=
=
kK2
,
"kVHeaddim should equal
t
o kK2"
);
static_assert
(
kVHeaddim
>
=
kK2
,
"kVHeaddim should
be
equal o
r greater than
kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
View file @
63b152d6
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKReg
Slice
BlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
...
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
// KT, Reg ->LDS ->Reg
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
block_sync_lds
();
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
auto
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
//---------------------------- Loop Load in ----------------------------//
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
// Q: HBM ->Reg ->LDS
auto
q_dram_window
=
auto
q_dram_window
=
...
@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
...
@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
...
@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_do_lds_write_window
=
make_tile_window
(
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
index_t
i_total_loops
=
0
;
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
=
=
kK0
,
"kQKHeaddim should equal
t
o kK0"
);
static_assert
(
kQKHeaddim
>
=
kK0
,
"kQKHeaddim should
be
equal o
r greater than
kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
=
=
kK2
,
"kVHeaddim should equal
t
o kK2"
);
static_assert
(
kVHeaddim
>
=
kK2
,
"kVHeaddim should
be
equal o
r greater than
kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
...
@@ -827,6 +824,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -827,6 +824,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
},
},
s_acc
,
s_acc
,
bias_s_tile
);
bias_s_tile
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
{
...
@@ -918,6 +916,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -918,6 +916,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 4, OGrad@V Gemm2
// STAGE 4, OGrad@V Gemm2
auto
dp_acc
=
SPGradBlockTileType
{};
auto
dp_acc
=
SPGradBlockTileType
{};
...
@@ -927,6 +926,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -927,6 +926,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 5, P^T(PGrad^T - D)
// STAGE 5, P^T(PGrad^T - D)
auto
ds
=
SPGradBlockTileType
{};
auto
ds
=
SPGradBlockTileType
{};
...
@@ -965,6 +965,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -965,6 +965,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
// STAGE 6, SGrad^T@Q^T Gemm3
// STAGE 6, SGrad^T@Q^T Gemm3
...
@@ -984,6 +985,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -984,6 +985,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 7, SGrad@K^T Gemm4
// STAGE 7, SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
clear_tile
(
dq_acc
);
...
@@ -1005,6 +1007,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -1005,6 +1007,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
__builtin_amdgcn_sched_barrier
(
0
);
// Results Scale
// Results Scale
if
constexpr
(
FmhaDropout
::
IsDropout
)
if
constexpr
(
FmhaDropout
::
IsDropout
)
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment