gaoqiong / composable_kernel_ROCM · Commits
"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "546d736d66f48b2a2536d0bedd71823e88100b04"
Commit f0bbc5db, authored Feb 13, 2025 by Bartlomiej Kocot

[CK TILE] GEMM with packed i4

Parent: 0e5e29c4

Changes: 28 · Showing 20 changed files with 715 additions and 187 deletions (+715, -187)
Changed files:

    example/ck_tile/03_gemm/CMakeLists.txt                       +1   -0
    example/ck_tile/03_gemm/gemm_basic.hpp                       +16  -1
    example/ck_tile/03_gemm/run_gemm_example.inc                 +98  -18
    example/ck_tile/03_gemm/universal_gemm_pk_int4.cpp           +308 -0
    include/ck_tile/core/arch/amd_buffer_addressing.hpp          +4   -2
    include/ck_tile/core/container/thread_buffer.hpp             +2   -2
    include/ck_tile/core/numeric/bfloat16.hpp                    +3   -2
    include/ck_tile/core/numeric/float8.hpp                      +2   -0
    include/ck_tile/core/numeric/half.hpp                        +21  -0
    include/ck_tile/core/numeric/int8.hpp                        +2   -1
    include/ck_tile/core/numeric/numeric.hpp                     +1   -0
    include/ck_tile/core/numeric/pk_int4.hpp                     +15  -2
    include/ck_tile/core/numeric/vector_type.hpp                 +17  -20
    include/ck_tile/core/tensor/buffer_view.hpp                  +88  -44
    include/ck_tile/core/tensor/static_distributed_tensor.hpp    +13  -9
    include/ck_tile/core/tensor/tensor_view.hpp                  +37  -28
    include/ck_tile/core/tensor/tile_window.hpp                  +28  -19
    include/ck_tile/core/tensor/tile_window_linear.hpp           +34  -23
    include/ck_tile/host/check_err.hpp                           +16  -14
    include/ck_tile/host/fill.hpp                                +9   -2
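For orientation before the per-file diffs: the commit teaches the CK Tile GEMM path to treat the B operand as packed int4. A pk_int4_t stores two 4-bit values in one int8_t, so the logical element count stays the same while the byte count halves; the new PackedSize trait (2 for pk_int4_t, 1 everywhere else) carries that factor through buffer sizing and offset math. A minimal standalone sketch of the packing convention, with hypothetical helper names; only the high-nibble/low-nibble layout is taken from the diff (lanes are shown unsigned 0..15, as in the host-side permute below):

    #include <cstdint>
    #include <cstdio>

    // Two 4-bit lanes in one byte: hi nibble = even lane, lo nibble = odd lane,
    // matching the (i4x2 >> 4) & 0xf / (i4x2 >> 0) & 0xf unpacking in the diff.
    std::int8_t pack_i4x2(std::int8_t hi, std::int8_t lo)
    {
        return static_cast<std::int8_t>(((hi & 0xf) << 4) | (lo & 0xf));
    }

    std::int8_t unpack_hi(std::int8_t i4x2) { return (i4x2 >> 4) & 0xf; }
    std::int8_t unpack_lo(std::int8_t i4x2) { return (i4x2 >> 0) & 0xf; }

    int main()
    {
        std::int8_t b = pack_i4x2(5, 3);
        std::printf("hi=%d lo=%d\n", unpack_hi(b), unpack_lo(b)); // hi=5 lo=3
    }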
example/ck_tile/03_gemm/CMakeLists.txt

@@ -3,3 +3,4 @@ add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
 target_compile_options(tile_example_gemm_universal PRIVATE
     -mllvm -enable-noalias-to-md-conversion=0
 )
+add_executable(tile_example_gemm_universal_pk_int4 EXCLUDE_FROM_ALL universal_gemm_pk_int4.cpp)
example/ck_tile/03_gemm/gemm_basic.hpp

@@ -35,7 +35,7 @@
 #error "unsupported CK_TILE_PIPELINE_DEFAULT value"
 #endif

-template <typename DataType>
+template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
 struct GemmBasicTypeConfig;

 template <>
@@ -75,6 +75,15 @@ struct GemmBasicTypeConfig<ck_tile::bf8_t>
     using CDataType = ck_tile::half_t;
 };

+template <>
+struct GemmBasicTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
+{
+    using ADataType   = ck_tile::half_t;
+    using BDataType   = ck_tile::pk_int4_t;
+    using AccDataType = float;
+    using CDataType   = ck_tile::half_t;
+};
+
 template <typename T>
 struct DataTypeTraits;
@@ -114,6 +123,12 @@ struct DataTypeTraits<ck_tile::bf8_t>
     static constexpr const char* name = "bf8";
 };

+template <>
+struct DataTypeTraits<ck_tile::pk_int4_t>
+{
+    static constexpr const char* name = "pk_int4_t";
+};
+
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
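Because the new B and C template parameters default to ADataType, existing single-type call sites such as GemmBasicTypeConfig<ck_tile::half_t> keep compiling unchanged, while the int4 path names B explicitly. A hedged compile-time sketch of how the new specialization is selected (assumes the definitions added above plus <type_traits>):

    // Sketch only: the specialization added in this file resolves for the
    // half_t x pk_int4_t -> half_t combination, with fp32 accumulation.
    using PkInt4Config =
        GemmBasicTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>;

    static_assert(std::is_same_v<typename PkInt4Config::BDataType, ck_tile::pk_int4_t>);
    static_assert(std::is_same_v<typename PkInt4Config::AccDataType, float>);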
example/ck_tile/03_gemm/run_gemm_example.inc

@@ -29,6 +29,60 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
     // Use higher threshold
     return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
 }

+template <typename Tensor>
+void permute_tensor_b(Tensor& tensor)
+{
+    const ck_tile::index_t K = tensor.get_length(0);
+    const ck_tile::index_t N = tensor.get_length(1);
+
+    // vector pk_i4x4 permute
+    for(int i = 0; i < N; i++)
+    {
+        for(int j = 0; j < K; j += 8)
+        {
+            int8_t input[8];
+            for(int k = 0; k < 4; k++)
+            {
+                int8_t i4x2      = tensor(j + k * 2, i).data;
+                input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
+                input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
+            }
+
+            // permute 01234567->20643175
+            {
+                int8_t hi        = input[2];
+                int8_t lo        = input[0];
+                int8_t i4x2      = (hi << 4) | lo;
+                tensor(j + 0, i) = i4x2;
+            }
+            {
+                int8_t hi        = input[6];
+                int8_t lo        = input[4];
+                int8_t i4x2      = (hi << 4) | lo;
+                tensor(j + 2, i) = i4x2;
+            }
+            {
+                int8_t hi        = input[3];
+                int8_t lo        = input[1];
+                int8_t i4x2      = (hi << 4) | lo;
+                tensor(j + 4, i) = i4x2;
+            }
+            {
+                int8_t hi        = input[7];
+                int8_t lo        = input[5];
+                int8_t i4x2      = (hi << 4) | lo;
+                tensor(j + 6, i) = i4x2;
+            }
+        }
+    }
+}
+
 template <typename ADataType,
           typename BDataType,
@@ -83,7 +137,12 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
     return ave_time;
 }

-template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
+template <typename ADataType,
+          typename BDataType = ADataType,
+          typename CDataType = ADataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
 int run_gemm_example_with_layouts(int argc,
                                   char* argv[],
                                   const ALayout a_layout = ALayout{},
@@ -94,10 +153,9 @@ int run_gemm_example_with_layouts(int argc,
     if(!result)
         return -1;

-    using ADataType   = typename GemmBasicTypeConfig<PrecType>::ADataType;
-    using BDataType   = typename GemmBasicTypeConfig<PrecType>::BDataType;
-    using CDataType   = typename GemmBasicTypeConfig<PrecType>::CDataType;
-    using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
+    using AccDataType = typename GemmBasicTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
+    constexpr ck_tile::index_t PackedSizeA = ck_tile::numeric_traits<ADataType>::PackedSize;
+    constexpr ck_tile::index_t PackedSizeB = ck_tile::numeric_traits<BDataType>::PackedSize;

     ck_tile::index_t M = arg_parser.get_int("m");
     ck_tile::index_t N = arg_parser.get_int("n");
@@ -107,10 +165,10 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
     ck_tile::index_t stride_C = arg_parser.get_int("stride_c");

     ck_tile::index_t kbatch      = arg_parser.get_int("split_k");
     int n_warmup                 = arg_parser.get_int("warmup");
     int n_repeat                 = arg_parser.get_int("repeat");
     ck_tile::index_t init_method = arg_parser.get_int("init");

     stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
     stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -123,16 +181,23 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::HostTensor<CDataType> c_m_n_dev_result(
         ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

     if(init_method == 0)
     {
         ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
         ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
     }
     else if(init_method == 1)
     {
         ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
         ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
     }
     else if(init_method == 2)
     {
         ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
         ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
     }
     else
     {
         a_m_k.SetZero();
         b_k_n.SetZero();
     }
@@ -142,7 +207,17 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

     a_m_k_dev_buf.ToDevice(a_m_k.data());
-    b_k_n_dev_buf.ToDevice(b_k_n.data());
+    if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
+    {
+        // Permute data for device implementation
+        ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
+        permute_tensor_b(b_k_n_dev);
+        b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
+    }
+    else
+    {
+        b_k_n_dev_buf.ToDevice(b_k_n.data());
+    }
     c_m_n_dev_buf.SetZero();
     c_m_n_dev_result.SetZero();
@@ -188,6 +263,11 @@ int run_gemm_example_with_layouts(int argc,
     }
     else if(arg_parser.get_int("v") == 2)
     {
+        if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
+        {
+            // Restore input for B for gpu reference
+            b_k_n_dev_buf.ToDevice(b_k_n.data());
+        }
         ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
             ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
         ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
@@ -198,17 +278,17 @@ int run_gemm_example_with_layouts(int argc,
         BDataType* d_B;
         CDataType* d_C;

-        ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
-        ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
+        ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType) / PackedSizeA));
+        ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType) / PackedSizeB));
         ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
         ck_tile::hip_check_error(hipMemcpy(d_A,
                                            a_m_k_dev_buf.GetDeviceBuffer(),
-                                           M * K * sizeof(ADataType),
+                                           M * K * sizeof(ADataType) / PackedSizeA,
                                            hipMemcpyHostToDevice));
         ck_tile::hip_check_error(hipMemcpy(d_B,
                                            b_k_n_dev_buf.GetDeviceBuffer(),
-                                           N * K * sizeof(BDataType),
+                                           N * K * sizeof(BDataType) / PackedSizeB,
                                            hipMemcpyHostToDevice));

         ck_tile::reference_gemm_gpu<ADataType,
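To see the 01234567->20643175 shuffle concretely: within each group of eight int4 lanes (four packed bytes) along K, the bytes are rewritten so byte 0 holds lanes (2,0), byte 1 lanes (6,4), byte 2 lanes (3,1), and byte 3 lanes (7,5), which is the order the device kernel expects. A standalone round-trip sketch on one 8-lane group (hypothetical buffer; lane values 0..7 chosen so the permutation is visible):

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // Lanes 0..7 stored as four packed bytes: (hi, lo) = (0,1), (2,3), (4,5), (6,7)
        std::int8_t packed[4];
        for(int k = 0; k < 4; k++)
            packed[k] = static_cast<std::int8_t>(((2 * k) << 4) | (2 * k + 1));

        // Unpack exactly as permute_tensor_b does
        std::int8_t input[8];
        for(int k = 0; k < 4; k++)
        {
            input[k * 2 + 0] = (packed[k] >> 4) & 0xf;
            input[k * 2 + 1] = (packed[k] >> 0) & 0xf;
        }

        // Repack in the 20643175 order
        const int hi_idx[4] = {2, 6, 3, 7};
        const int lo_idx[4] = {0, 4, 1, 5};
        for(int k = 0; k < 4; k++)
            packed[k] = static_cast<std::int8_t>((input[hi_idx[k]] << 4) | input[lo_idx[k]]);

        for(int k = 0; k < 4; k++)
            std::printf("byte %d: hi=%d lo=%d\n", k, (packed[k] >> 4) & 0xf, packed[k] & 0xf);
        // Prints lanes (2,0), (6,4), (3,1), (7,5)
    }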
example/ck_tile/03_gemm/universal_gemm_pk_int4.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>

#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
    // Memory friendly for Interwave scheduler
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 32;
    constexpr ck_tile::index_t K_Tile = 64;

    constexpr ck_tile::index_t M_Warp = 4;
    constexpr ck_tile::index_t N_Warp = 1;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 8;

    constexpr bool DoubleSmemBuffer = false;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
    // Compute friendly for Intrawave scheduler
    constexpr ck_tile::index_t M_Tile = 256;
    constexpr ck_tile::index_t N_Tile = 256;
    constexpr ck_tile::index_t K_Tile = 64;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 16;

    constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
    // Compute friendly for Intrawave scheduler
    // Using the ping pong reader in the lds level
    constexpr ck_tile::index_t M_Tile = 256;
    constexpr ck_tile::index_t N_Tile = 256;
    constexpr ck_tile::index_t K_Tile = 32;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 16;

    constexpr bool DoubleSmemBuffer = true;
#endif

    constexpr bool kPadM = false;
    constexpr bool kPadN = false;
    constexpr bool kPadK = false;

    constexpr bool TransposeC = false;

    constexpr int kBlockPerCu = 1;

    constexpr ck_tile::index_t TileParitionerGroupNum = 8;
    constexpr ck_tile::index_t TileParitionerM01      = 4;

    // ===============================================

    using GemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
    using TilePartitioner = ck_tile::
        GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;

    using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
                                                                 kPadN,
                                                                 kPadK,
                                                                 DoubleSmemBuffer,
                                                                 ALayout,
                                                                 BLayout,
                                                                 CLayout,
                                                                 TransposeC>;
    using GemmPipelineProblem =
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

    using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;

    const ck_tile::index_t k_grain     = args.k_batch * K_Tile;
    const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * K_Tile;
    const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

    float ave_time{0};

    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v  = tail_number_.value;
        constexpr auto scheduler      = GEMM_PIPELINE_SCHEDULER;

        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                           BDataType,
                                                                           AccDataType,
                                                                           GemmShape,
                                                                           GemmUniversalTraits,
                                                                           scheduler,
                                                                           has_hot_loop_v,
                                                                           tail_number_v>;

        using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<AccDataType,
                                             CDataType,
                                             CLayout,
                                             GemmPipelineProblem::kBlockSize,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             M_Warp,
                                             N_Warp,
                                             M_Warp_Tile,
                                             N_Warp_Tile,
                                             K_Warp_Tile,
                                             UniversalGemmProblem::TransposeC>>;
        using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
        auto kargs   = Kernel::MakeKernelArgs(args);

        const dim3 grids      = Kernel::GridSize(args.M, args.N, args.k_batch);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args:"
                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        ave_time = ck_tile::launch_kernel(
            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

        return ave_time;
    };

    if(has_hot_loop)
    {
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
        if(tail_num == ck_tile::TailNumber::Full)
        {
            Run(ck_tile::bool_constant<true>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        }
        else
        {
            std::ostringstream err;
            err << "For compute pipeline tail number should always be Full, but have \""
                << tail_num << "\" which is not supported! PrefetchStages: "
                << BaseGemmPipeline::PrefetchStages << "\nFile: " << __FILE__ << ":" << __LINE__
                << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
        // Tail pipeline One to Seven
        if(tail_num == ck_tile::TailNumber::One)
        {
            Run(ck_tile::bool_constant<true>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
        }
        else if(tail_num == ck_tile::TailNumber::Full)
        {
            Run(ck_tile::bool_constant<true>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        }

        if constexpr(BaseGemmPipeline::PrefetchStages > 2)
        {
            if(tail_num == ck_tile::TailNumber::Two)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
            }
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 3)
        {
            if(tail_num == ck_tile::TailNumber::Three)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
            }
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 4)
        {
            if(tail_num == ck_tile::TailNumber::Four)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
            }
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 5)
        {
            if(tail_num == ck_tile::TailNumber::Five)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
            }
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 6)
        {
            if(tail_num == ck_tile::TailNumber::Six)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
            }
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 7)
        {
            if(tail_num == ck_tile::TailNumber::Seven)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
            }
        }
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
        if(tail_num == ck_tile::TailNumber::Three)
        {
            Run(ck_tile::bool_constant<true>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
        }
        else
        {
            Run(ck_tile::bool_constant<true>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
        }
#endif
    }
    else
    {
        // Tail number always Full - #PrefetchStages
        if(tail_num == ck_tile::TailNumber::Full)
        {
            Run(ck_tile::bool_constant<false>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        }
        else
        {
            std::ostringstream err;
            err << "When there's no hot loop, this tail number \"" << tail_num
                << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
                << "\nFile: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }
    }

    return ave_time;
}

#include "run_gemm_example.inc"

int run_gemm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    std::string data_type = arg_parser.get_str("prec");
    std::string a_layout  = arg_parser.get_str("a_layout");
    std::string b_layout  = arg_parser.get_str("b_layout");

    if(a_layout == "R" && b_layout == "C")
    {
        return run_gemm_example_with_layouts<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
            argc, argv, Row{}, Col{}, Row{});
    }
    else if(a_layout == "C" && b_layout == "C")
    {
        return run_gemm_example_with_layouts<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
            argc, argv, Col{}, Col{}, Row{});
    }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }
}

int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
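The split-K bookkeeping in gemm_calc rounds K up to a whole number of k_batch * K_Tile grains before sizing the main loop. A worked sketch of that arithmetic (plain C++; values illustrative, and it assumes GetLoopNum reduces to K_split / K_Tile, which the diff does not show):

    #include <cstdio>

    int main()
    {
        const int K = 1000, k_batch = 2, K_Tile = 64;

        // As in gemm_calc: round K up to whole grains of k_batch*K_Tile,
        // then express the per-batch K extent in elements.
        const int k_grain  = k_batch * K_Tile;                     // 128
        const int K_split  = (K + k_grain - 1) / k_grain * K_Tile; // ceil(1000/128)=8 -> 512
        const int num_loop = K_split / K_Tile;                     // assumed GetLoopNum: 8

        std::printf("k_grain=%d K_split=%d num_loop=%d\n", k_grain, K_split, num_loop);
    }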
include/ck_tile/core/arch/amd_buffer_addressing.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -1309,7 +1309,9 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
                       (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                          (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                          (std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-                         (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
+                         (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+                         (std::is_same<T, pk_int4_t>::value &&
+                          (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
                      "wrong! not implemented");

     using rtn_type = thread_buffer<T, N>;
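pk_int4_t is the only type here allowed N == 32: thirty-two packed int4 elements occupy sixteen bytes, the same dwordx4 buffer transaction as sixteen int8 elements. A constexpr sanity check under that assumption (PackedSize == 2 for pk_int4_t, element storage one byte):

    #include <cstddef>

    // Sketch: bytes moved per buffer_load at the widest allowed N of each type.
    constexpr std::size_t bytes(std::size_t n_elems, std::size_t packed_size)
    {
        return n_elems / packed_size; // int8_t and pk_int4_t both store one byte
    }

    static_assert(bytes(16, 1) == 16); // int8_t,    N == 16 -> 16 B (dwordx4)
    static_assert(bytes(32, 2) == 16); // pk_int4_t, N == 32 -> still 16 B

    int main() {}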
include/ck_tile/core/container/thread_buffer.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -156,7 +156,7 @@ struct vector_traits;
 template <typename T, index_t N>
 struct vector_traits<thread_buffer<T, N>>
 {
-    using scalar_type = T;
+    using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
     static constexpr index_t vector_size = N;
 };
include/ck_tile/core/numeric/bfloat16.hpp

@@ -382,8 +382,9 @@ struct numeric_traits;
 template <>
 struct numeric_traits<bfloat16_t>
 {
-    static constexpr int exp  = 8;
-    static constexpr int mant = 7;
+    static constexpr int exp        = 8;
+    static constexpr int mant       = 7;
+    static constexpr int PackedSize = 1;
 };

 #if CK_TILE_USE_CUSTOM_DATA_TYPE
include/ck_tile/core/numeric/float8.hpp

@@ -225,6 +225,7 @@ struct numeric_traits<fp8_t>
     static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
 #endif
     static constexpr uint8_t abs_mask = 0x7F;
+    static constexpr int PackedSize   = 1;
 };

 template <>
@@ -242,6 +243,7 @@ struct numeric_traits<bf8_t>
     static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
 #endif
     static constexpr uint8_t abs_mask = 0x7F;
+    static constexpr int PackedSize   = 1;
 };

 // below is sw fp8 conversion, not utilizing hw instruction
include/ck_tile/core/numeric/half.hpp

@@ -241,6 +241,7 @@ struct numeric_traits<half_t>
     static constexpr uint16_t NegInf = 0xFC00;
     static constexpr uint16_t NaN    = 0x7C01;
     static constexpr uint16_t Neg0   = 0x8000;
+    static constexpr int PackedSize  = 1;
     using bitwise_type = uint16_t;
 };
@@ -383,4 +384,24 @@ half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x)))
 CK_TILE_DEVICE
 half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
 #endif
+
+using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
+
+CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
+{
+    fp16x2_t vector_res;
+    vector_res.x = x.x + y.x;
+    vector_res.y = x.y + y.y;
+    return vector_res;
+}
+
+CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
+{
+    fp16x2_t c;
+    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
+    return c;
+}
 } // namespace ck_tile
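The two pk_add_f16 overloads give one call site for packed fp16 addition: on device it maps to a single v_pk_add_f16 instruction, on host it is a plain two-lane add. A host-side usage sketch (clang vector extensions; values illustrative):

    // Self-contained host sketch of the two-lane fp16 add the diff introduces.
    using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));

    fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
    {
        fp16x2_t r;
        r.x = x.x + y.x; // lane 0
        r.y = x.y + y.y; // lane 1
        return r;
    }

    int main()
    {
        fp16x2_t a = {(_Float16)1.0f, (_Float16)2.0f};
        fp16x2_t b = {(_Float16)0.5f, (_Float16)0.25f};
        fp16x2_t c = pk_add_f16(a, b); // {1.5, 2.25}
        return (float)c.x == 1.5f ? 0 : 1;
    }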
include/ck_tile/core/numeric/int8.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #include "ck_tile/core/config.hpp"
 #include "ck_tile/core/numeric/half.hpp"
@@ -91,6 +91,7 @@ struct numeric_traits<int8_t>
     static constexpr uint32_t NegInf = 0xFC00;
     static constexpr uint32_t NaN    = 0x7C01;
     static constexpr uint32_t Neg0   = 0x8000;
+    static constexpr int PackedSize  = 1;
     using bitwise_type = uint16_t;
 };
 #endif
include/ck_tile/core/numeric/numeric.hpp

@@ -94,6 +94,7 @@ struct numeric_traits<float>
     static constexpr uint32_t NegInf = 0xFF800000;
     static constexpr uint32_t NaN    = 0x7F800001;
     static constexpr uint32_t Neg0   = 0x80000000;
+    static constexpr int PackedSize  = 1;
     using bitwise_type = uint32_t;
 };
include/ck_tile/core/numeric/pk_int4.hpp

@@ -21,8 +21,8 @@ struct pk_int4_t
 {
     using type = int8_t;
     type data;
-    __host__ __device__ constexpr pk_int4_t() : data{type{}} {}
-    __host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
+    CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
+    CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
 };

 // limits
@@ -91,6 +91,19 @@ struct numeric<pk_int4_t>
     CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
 };

+template <typename T>
+struct numeric_traits;
+
+template <>
+struct numeric_traits<pk_int4_t>
+{
+    static constexpr int PackedSize = 2;
+};
+
+using fp32x2_t = float __attribute__((ext_vector_type(2)));
+using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
+using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
+
 CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
 {
     uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
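The conversion helpers begin with a byte-wise bit_cast; the rest of pk_int4_t_to_fp32x2_t is cut off in this view. The standard way to widen two signed nibbles is an arithmetic-shift sign extension, so a hedged sketch of that step (not necessarily the exact body of the library function):

    #include <cstdint>

    // Hedged sketch: sign-extend the two 4-bit lanes of a packed byte and widen
    // to float, the shape of conversion pk_int4_t_to_fp32x2_t needs. Shifts on
    // int8_t are arithmetic, so each nibble's sign bit is propagated.
    void unpack_pk_int4_to_float(std::uint8_t x_u8, float out[2])
    {
        const std::int8_t hi = static_cast<std::int8_t>(x_u8) >> 4;      // lane 0
        const std::int8_t lo = static_cast<std::int8_t>(x_u8 << 4) >> 4; // lane 1
        out[0] = static_cast<float>(hi);
        out[1] = static_cast<float>(lo);
    }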
include/ck_tile/core/numeric/vector_type.hpp

@@ -10,6 +10,7 @@
 #include "ck_tile/core/numeric/float8.hpp"
 #include "ck_tile/core/numeric/half.hpp"
 #include "ck_tile/core/numeric/bfloat16.hpp"
+#include "ck_tile/core/numeric/pk_int4.hpp"
 #include "ck_tile/core/utility/type_traits.hpp"

 namespace ck_tile {
@@ -34,7 +35,11 @@ template <typename T_, index_t N_>
 struct ext_vector
 {
     static constexpr index_t N = N_;
-    using value_type = typename native_t<remove_cvref_t<T_>>::type;
+    // struct type is not supported for ext_vector
+    using value_type =
+        std::conditional_t<std::is_same_v<typename native_t<remove_cvref_t<T_>>::type, pk_int4_t>,
+                           int8_t,
+                           typename native_t<remove_cvref_t<T_>>::type>;
     static_assert(!std::is_class_v<value_type>);
     using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
 };
@@ -58,7 +63,8 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
 template <typename T>
 struct vector_traits
 {
-    using scalar_type = remove_cvref_t<T>;
+    using scalar_type =
+        std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
     static constexpr index_t vector_size = 1;
 };
@@ -66,7 +72,7 @@ struct vector_traits
 template <typename T, index_t N>
 struct vector_traits<T __attribute__((ext_vector_type(N)))>
 {
-    using scalar_type = T;
+    using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
     static constexpr index_t vector_size = N;
 };
@@ -200,21 +206,12 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
 using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
 #endif

-CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
-{
-    fp16x2_t vector_res;
-    vector_res.x = x.x + y.x;
-    vector_res.y = x.y + y.y;
-    return vector_res;
-}
-
-CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
-{
-    fp16x2_t c;
-    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
-    return c;
-}
+// pk_int4_t
+using pk_int4x2_t  = int8_t __attribute((ext_vector_type(2)));
+using pk_int4x4_t  = int8_t __attribute((ext_vector_type(4)));
+using pk_int4x8_t  = int8_t __attribute((ext_vector_type(8)));
+using pk_int4x16_t = int8_t __attribute((ext_vector_type(16)));
+using pk_int4x32_t = int8_t __attribute((ext_vector_type(32)));
+using pk_int4x64_t = int8_t __attribute((ext_vector_type(64)));
 } // namespace ck_tile
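Clang's ext_vector_type attribute rejects class types, which is why every pk_int4 vector alias above is declared over int8_t and vector_traits maps the scalar back to the storage type. A compile-time sketch of that mapping, with a struct stand-in for ck_tile::pk_int4_t so it is self-contained:

    #include <type_traits>

    struct pk_int4_t { signed char data; }; // stand-in for ck_tile::pk_int4_t

    template <typename T>
    struct vector_traits
    {
        // Same trick as the diff: pk_int4_t decays to its int8_t storage type
        // so it can participate in ext_vector_type aliases.
        using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, signed char, T>;
    };

    static_assert(std::is_same_v<vector_traits<pk_int4_t>::scalar_type, signed char>);
    static_assert(std::is_same_v<vector_traits<float>::scalar_type, float>);

    int main() {}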
include/ck_tile/core/tensor/buffer_view.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -231,6 +231,8 @@ struct buffer_view<address_space_enum::global,
     int32x4_t cached_buf_res_;
     remove_cvref_t<T> invalid_element_value_ = T{0};

+    static constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
+
     CK_TILE_HOST_DEVICE constexpr buffer_view()
         : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
     {
@@ -255,7 +257,8 @@ struct buffer_view<address_space_enum::global,
     // Must call for buffers that need *_raw load/store
     CK_TILE_HOST_DEVICE void init_raw()
     {
-        cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
+        cached_buf_res_ =
+            make_wave_buffer_resource(p_data_, (buffer_size_ / PackedSize) * sizeof(type));
     }

     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
@@ -307,7 +310,7 @@ struct buffer_view<address_space_enum::global,
                 t_per_x,
                 Coherence,
                 oob_conditional_check>(
-                p_data_, i + linear_offset, is_valid_element, buffer_size_);
+                p_data_, i + linear_offset, is_valid_element, buffer_size_ / PackedSize);
         }
         else
         {
@@ -318,7 +321,7 @@ struct buffer_view<address_space_enum::global,
                 oob_conditional_check>(p_data_,
                                        i + linear_offset,
                                        is_valid_element,
-                                       buffer_size_,
+                                       buffer_size_ / PackedSize,
                                        invalid_element_value_);
         }
     }
@@ -533,7 +536,7 @@ struct buffer_view<address_space_enum::global,
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
         amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
-            x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
+            x, p_data_, i + linear_offset, is_valid_element, buffer_size_ / PackedSize);
     }
     else
     {
@@ -569,7 +572,7 @@ struct buffer_view<address_space_enum::global,
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
         amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
-            x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
+            x, p_data_, i, linear_offset, is_valid_element, buffer_size_ / PackedSize);
     }

     template <typename X,
@@ -614,7 +617,7 @@ struct buffer_view<address_space_enum::global,
         if constexpr(use_amd_buffer_addressing)
         {
             amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
-                x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
+                x, p_data_, i + linear_offset, is_valid_element, buffer_size_ / PackedSize);
         }
         else
         {
@@ -654,7 +657,7 @@ struct buffer_view<address_space_enum::global,
             Coherence,
             oob_conditional_check,
             pre_nop>(
-            x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
+            x, p_data_, i, linear_offset, is_valid_element, buffer_size_ / PackedSize);
     }

     template <typename X,
@@ -688,7 +691,7 @@ struct buffer_view<address_space_enum::global,
         if constexpr(use_amd_buffer_addressing)
         {
             amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
-                x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
+                x, p_data_, i + linear_offset, is_valid_element, buffer_size_ / PackedSize);
         }
         else if(is_valid_element)
         {
@@ -897,83 +900,124 @@ struct buffer_view<address_space_enum::lds,
     // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
     // ds_write_b128
     // TODO: remove this after compiler fix
-    static_assert((std::is_same<remove_cvref_t<T>, int8_t>::value &&
-                   std::is_same<remove_cvref_t<X>, int8_t>::value) ||
-                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
-                   std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
-                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
-                   std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
-                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
-                   std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
-                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
-                   std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
-                  (std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
-                   std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
-                  (std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
-                   std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
-                  (std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
-                   std::is_same<remove_cvref_t<X>, int8x16_t>::value),
-                  "wrong! not implemented for this combination, please add "
-                  "implementation");
+    // ext_vector_type for pk_int4 must use int8_t as type
+    static_assert((std::is_same<remove_cvref_t<T>, int8_t>::value &&
+                   std::is_same<remove_cvref_t<X>, int8_t>::value) ||
+                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
+                   std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
+                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
+                   std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
+                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
+                   std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
+                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
+                   std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
+                  (std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
+                   std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
+                  (std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
+                   std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
+                  (std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
+                   std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
+                  (std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
+                   std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>::value) ||
+                  (std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
+                   std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>::value) ||
+                  (std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
+                   std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>::value) ||
+                  (std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
+                   std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>::value) ||
+                  (std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
+                   std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>::value) ||
+                  (std::is_same<remove_cvref_t<T>, pk_int4x4_t>::value &&
+                   std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>::value) ||
+                  (std::is_same<remove_cvref_t<T>, pk_int4x8_t>::value &&
+                   std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>::value) ||
+                  (std::is_same<remove_cvref_t<T>, pk_int4x16_t>::value &&
+                   std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>::value),
+                  "wrong! not implemented for this combination, please add "
+                  "implementation");

-    if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-                 std::is_same<remove_cvref_t<X>, int8_t>::value)
+    if constexpr((std::is_same<remove_cvref_t<T>, int8_t>::value &&
+                  std::is_same<remove_cvref_t<X>, int8_t>::value) ||
+                 (std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
+                  std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>::value))
     {
         // HACK: cast pointer of x is bad
         // TODO: remove this after compiler fix
         *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
             *c_style_pointer_cast<const int8_t*>(&x);
     }
-    else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-                      std::is_same<remove_cvref_t<X>, int8x2_t>::value)
+    else if constexpr((std::is_same<remove_cvref_t<T>, int8_t>::value &&
+                       std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
+                      (std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
+                       std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>::value))
     {
         // HACK: cast pointer of x is bad
         // TODO: remove this after compiler fix
         *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
             *c_style_pointer_cast<const int16_t*>(&x);
     }
-    else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-                      std::is_same<remove_cvref_t<X>, int8x4_t>::value)
+    else if constexpr((std::is_same<remove_cvref_t<T>, int8_t>::value &&
                       std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
+                      (std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
+                       std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>::value))
     {
         // HACK: cast pointer of x is bad
         // TODO: remove this after compiler fix
         *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
             *c_style_pointer_cast<const int32_t*>(&x);
     }
-    else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-                      std::is_same<remove_cvref_t<X>, int8x8_t>::value)
+    else if constexpr((std::is_same<remove_cvref_t<T>, int8_t>::value &&
+                       std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
+                      (std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
+                       std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>::value))
     {
         // HACK: cast pointer of x is bad
         // TODO: remove this after compiler fix
         *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
             *c_style_pointer_cast<const int32x2_t*>(&x);
     }
-    else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-                      std::is_same<remove_cvref_t<X>, int8x16_t>::value)
+    else if constexpr((std::is_same<remove_cvref_t<T>, int8_t>::value &&
                       std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
+                      (std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
+                       std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>::value))
     {
         // HACK: cast pointer of x is bad
         // TODO: remove this after compiler fix
         *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
             *c_style_pointer_cast<const int32x4_t*>(&x);
     }
-    else if constexpr(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
-                      std::is_same<remove_cvref_t<X>, int8x4_t>::value)
+    else if constexpr((std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
                       std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
+                      (std::is_same<remove_cvref_t<T>, pk_int4x4_t>::value &&
+                       std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>::value))
     {
         // HACK: cast pointer of x is bad
         // TODO: remove this after compiler fix
         *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
             *c_style_pointer_cast<const int32_t*>(&x);
     }
-    else if constexpr(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
-                      std::is_same<remove_cvref_t<X>, int8x8_t>::value)
+    else if constexpr((std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
                       std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
+                      (std::is_same<remove_cvref_t<T>, pk_int4x8_t>::value &&
+                       std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>::value))
     {
         // HACK: cast pointer of x is bad
         // TODO: remove this after compiler fix
         *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
             *c_style_pointer_cast<const int32x2_t*>(&x);
     }
-    else if constexpr(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
-                      std::is_same<remove_cvref_t<X>, int8x16_t>::value)
+    else if constexpr((std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
                       std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
+                      (std::is_same<remove_cvref_t<T>, pk_int4x16_t>::value &&
+                       std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>::value))
     {
         // HACK: cast pointer of x is bad
         // TODO: remove this after compiler fix
 ...
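All of the buffer_size_ / PackedSize rewrites in this file serve one invariant: buffer_view counts logical elements, while the hardware buffer resource descriptor wants bytes, and for pk_int4_t those differ by a factor of two. A small sketch of the accounting (hypothetical tensor size; sizeof(pk_int4_t) is one byte):

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // A K x N = 64 x 4 B-operand of packed int4:
        const std::int64_t logical_elems = 64 * 4; // what the tile code indexes
        const int PackedSize             = 2;      // pk_int4_t: two elements per byte

        // What init_raw() now hands to make_wave_buffer_resource:
        // (buffer_size_ / PackedSize) * sizeof(type)
        const std::int64_t resource_bytes = (logical_elems / PackedSize) * 1;

        std::printf("%lld logical int4 elements -> %lld bytes\n",
                    (long long)logical_elems, (long long)resource_bytes); // 256 -> 128
    }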
include/ck_tile/core/tensor/static_distributed_tensor.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -27,6 +27,8 @@ struct static_distributed_tensor
     using ThreadTensorDesc =
         remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;

+    static constexpr index_t PackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
     static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
     static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
@@ -59,7 +61,7 @@ struct static_distributed_tensor
     CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
     {
-        return kThreadElementSpaceSize;
+        return kThreadElementSpaceSize / PackedSize;
     }

     template <index_t... YSliceOrigins, index_t... YSliceLengths>
@@ -79,8 +81,9 @@ struct static_distributed_tensor
         static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
             constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
-            sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
-                thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
+            sliced_thread_data(
+                number<sliced_thread_tensor_desc.calculate_offset(idx) / PackedSize>{}) =
+                thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}];
         });

         return sliced_thread_data;
@@ -101,8 +104,9 @@ struct static_distributed_tensor
         static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
             constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
-            thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
-                sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
+            thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}) =
+                sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx) /
+                                          PackedSize>{}];
         });
     }
@@ -115,7 +119,7 @@ struct static_distributed_tensor
         constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
             TileDistributedIndices{});

-        return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
+        return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{}];
     }

     template <typename TileDistributedIndices>
@@ -127,11 +131,11 @@ struct static_distributed_tensor
         constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
             TileDistributedIndices{});

-        return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
+        return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{});
     }

     //
-    thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
+    thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
 };

 template <typename DataType, typename StaticTileDistribution>
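The same PackedSize division appears in every thread-buffer access: two logical Y-space elements share one stored byte, so a logical offset is halved before indexing thread_buf_. A toy index calculation, as a sketch only (it assumes linear offsets; which nibble a given lane lands in is not shown by this diff):

    #include <cstdio>

    int main()
    {
        const int PackedSize = 2; // pk_int4_t

        // Logical thread-tensor offsets 0..7 map onto stored bytes 0..3,
        // the `calculate_offset(...) / PackedSize` pattern above.
        for(int logical = 0; logical < 8; ++logical)
            std::printf("logical %d -> storage index %d\n", logical, logical / PackedSize);
    }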
include/ck_tile/core/tensor/tensor_view.hpp View file @ f0bbc5db

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -45,6 +45,8 @@ struct tensor_view
     using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
     using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
     static constexpr auto DstInMemOp = DstInMemOp_;
+    static constexpr index_t PackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;

     CK_TILE_HOST_DEVICE constexpr tensor_view() = default;

@@ -81,8 +83,8 @@ struct tensor_view
                bool_constant<oob_conditional_check> = {}) const
     {
         return buf_.template get<X>(
-            coord.get_offset(),
-            linear_offset,
+            coord.get_offset() / PackedSize,
+            linear_offset / PackedSize,
             coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
             bool_constant<oob_conditional_check>{});
     }

@@ -99,8 +101,8 @@ struct tensor_view
                bool is_valid_element, // flag
                bool_constant<oob_conditional_check> = {}) const
     {
-        return buf_.template get<X>(coord.get_offset(),
-                                    linear_offset,
+        return buf_.template get<X>(coord.get_offset() / PackedSize,
+                                    linear_offset / PackedSize,
                                     is_valid_element,
                                     bool_constant<oob_conditional_check>{});
     }

@@ -122,8 +124,8 @@ struct tensor_view
     {
         return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
             dst,
-            coord.get_offset(),
-            linear_offset,
+            coord.get_offset() / PackedSize,
+            linear_offset / PackedSize,
             coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
             bool_constant<pre_nop>{});
     }

@@ -142,8 +144,12 @@ struct tensor_view
                    bool_constant<oob_conditional_check> = {},
                    bool_constant<pre_nop> = {}) const
     {
         return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
-            dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
+            dst,
+            coord.get_offset() / PackedSize,
+            linear_offset / PackedSize,
+            is_valid_element,
+            bool_constant<pre_nop>{});
     }

     template <typename X,

@@ -159,8 +165,8 @@ struct tensor_view
     {
         return buf_.template async_get<X>(
             smem,
-            coord.get_offset(),
-            linear_offset,
+            coord.get_offset() / PackedSize,
+            linear_offset / PackedSize,
             coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
             bool_constant<oob_conditional_check>{});
     }

@@ -178,8 +184,8 @@ struct tensor_view
                bool is_valid_element) const
     {
         return buf_.template async_get<X>(smem,
-                                          coord.get_offset(),
-                                          linear_offset,
+                                          coord.get_offset() / PackedSize,
+                                          linear_offset / PackedSize,
                                           is_valid_element,
                                           bool_constant<oob_conditional_check>{});
     }

@@ -198,8 +204,8 @@ struct tensor_view
     {
         return buf_.template async_get_raw<X>(
             smem,
-            coord.get_offset(),
-            linear_offset,
+            coord.get_offset() / PackedSize,
+            linear_offset / PackedSize,
             coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
             bool_constant<pre_nop>{});
     }

@@ -217,8 +223,11 @@ struct tensor_view
                         bool is_valid_element,
                         bool_constant<pre_nop> = {}) const
     {
         return buf_.template async_get_raw<X>(
-            smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
+            smem,
+            coord.get_offset() / PackedSize,
+            linear_offset / PackedSize,
+            is_valid_element,
+            bool_constant<pre_nop>{});
     }

     // X is vector of DataType.

@@ -236,8 +245,8 @@ struct tensor_view
              bool_constant<oob_conditional_check> = {})
     {
         buf_.template set<X, oob_conditional_check>(
-            coord.get_offset(),
-            linear_offset,
+            coord.get_offset() / PackedSize,
+            linear_offset / PackedSize,
             coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
             x);
     }

@@ -272,8 +281,8 @@ struct tensor_view
              bool_constant<oob_conditional_check> = {})
     {
         buf_.template set_raw<X, oob_conditional_check>(
-            coord.get_offset(),
-            linear_offset,
+            coord.get_offset() / PackedSize,
+            linear_offset / PackedSize,
             coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
             x);
     }

@@ -292,7 +301,7 @@ struct tensor_view
              bool_constant<oob_conditional_check> = {})
     {
         buf_.template set_raw<X, oob_conditional_check>(
-            coord.get_offset(), linear_offset, is_valid_element, x);
+            coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
     }

     // X is vector of DataType.

@@ -310,8 +319,8 @@ struct tensor_view
              bool_constant<oob_conditional_check> = {})
     {
         buf_.template update<DstInMemOp, X, oob_conditional_check>(
-            coord.get_offset(),
-            linear_offset,
+            coord.get_offset() / PackedSize,
+            linear_offset / PackedSize,
             coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
             x);
     }

@@ -330,7 +339,7 @@ struct tensor_view
              bool_constant<oob_conditional_check> = {})
     {
         buf_.template update<DstInMemOp, X, oob_conditional_check>(
-            coord.get_offset(), linear_offset, is_valid_element, x);
+            coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
     }

     // X is vector of DataType.

@@ -350,8 +359,8 @@ struct tensor_view
                 bool_constant<pre_nop> = {})
     {
         buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
-            coord.get_offset(),
-            linear_offset,
+            coord.get_offset() / PackedSize,
+            linear_offset / PackedSize,
             coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
             x);
     }

@@ -372,7 +381,7 @@ struct tensor_view
                 bool_constant<pre_nop> = {})
     {
         buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
-            coord.get_offset(), linear_offset, is_valid_element, x);
+            coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
     }

     CK_TILE_HOST_DEVICE void print() const
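A reduced sketch (hypothetical type, not the real tensor_view) of the convention this file adopts: all coordinate math stays in logical elements, and the conversion to storage units happens exactly once, at the buffer boundary. Dividing the coordinate offset and the linear offset separately assumes each is individually PackedSize-aligned, which the static_asserts added elsewhere in this commit appear to enforce for the vectorized paths.

#include <cstdint>
#include <cstdio>

template <int PackedSize>
struct packed_view
{
    const uint8_t* data;

    // Both offsets arrive in logical elements, exactly as a descriptor would
    // produce them; only here do they become storage (byte) units.
    uint8_t load_byte(int coord_offset, int linear_offset) const
    {
        return data[coord_offset / PackedSize + linear_offset / PackedSize];
    }
};

int main()
{
    const uint8_t storage[4] = {0x10, 0x32, 0x54, 0x76};
    packed_view<2> view{storage};
    // logical coord offset 4 -> byte 2, logical linear offset 2 -> byte 1,
    // so the access lands on storage[3]
    std::printf("0x%02x\n", view.load_byte(4, 2));
}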
include/ck_tile/core/tensor/tile_window.hpp View file @ f0bbc5db

@@ -97,13 +97,15 @@ struct tile_window_with_static_distribution
     }

 public:
+    static constexpr index_t PackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
+
     static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
     static constexpr index_t ScalarPerVector =
         get_vector_dim_y_scalar_per_vector().template at<1>();

     // using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
     // using vector_t = typename vector_type_t::type;
-    using vector_t = thread_buffer<DataType, ScalarPerVector>;
+    using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;

 private:
     static constexpr auto scalars_per_access_ = [] {

@@ -336,7 +338,7 @@ struct tile_window_with_static_distribution
                     bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
 #if 1
                 // write into distributed tensor
-                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                     constexpr auto idx_ys = generate_tuple(
                         [&](auto jj) {
                             return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)

@@ -345,10 +347,11 @@ struct tile_window_with_static_distribution
                         number<NDimY>{});

                     constexpr index_t d =
-                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                        Traits::PackedSize;

                     dst_tensor.get_thread_buffer().template at<d>() =
-                        vec_value.template get_as<DataType>()[j];
+                        vec_value.template get_as<DataType>()[j / Traits::PackedSize];
                 });
 #else
                 constexpr index_t d =

@@ -390,8 +393,9 @@ struct tile_window_with_static_distribution
         using SFC_Ys = typename Traits::SFC_Ys;
         static constexpr index_t YElementSize =
             TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
-        static_assert(YElementSize % Traits::ScalarPerVector == 0);
-        using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>;
+        static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
+        using vectorized_tbuf =
+            array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
         // StaticBuffer<address_space_enum::vgpr,
         //              vector_t,
         //              YElementSize / Traits::ScalarPerVector,

@@ -419,7 +423,8 @@ struct tile_window_with_static_distribution
             // data index [y0, y1, ...]
             constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
             constexpr index_t d =
-                tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
+                tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
+                Traits::PackedSize;
             static_assert(d % Traits::ScalarPerVector == 0);

             get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(

@@ -632,7 +637,7 @@ struct tile_window_with_static_distribution
                 // vector_type_t vec;
                 vector_t vec_value;

-                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                     constexpr auto idx_ys = generate_tuple(
                         [&](auto jj) {
                             return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)

@@ -641,9 +646,10 @@ struct tile_window_with_static_distribution
                         number<NDimY>{});

                     constexpr index_t d =
-                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                        Traits::PackedSize;

-                    vec_value.template get_as<DataType>()(j) =
+                    vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
                         dstr_tensor.get_thread_buffer().template at<d>();
                 });

@@ -698,7 +704,7 @@ struct tile_window_with_static_distribution
                 // read from distributed tensor
                 vector_t vec_value;
-                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                     constexpr auto idx_ys = generate_tuple(
                         [&](auto jj) {
                             return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)

@@ -706,8 +712,9 @@ struct tile_window_with_static_distribution
                         },
                         number<NDimY>{});

                     constexpr index_t d =
-                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-                    vec_value.template get_as<DataType>()(j) =
+                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                        Traits::PackedSize;
+                    vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
                         dstr_tensor.get_thread_buffer().template at<d>();
                 });

@@ -759,7 +766,7 @@ struct tile_window_with_static_distribution
                 // read from distributed tensor
                 vector_t vec_value;
-                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                     constexpr auto idx_ys = generate_tuple(
                         [&](auto jj) {
                             return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)

@@ -768,9 +775,10 @@ struct tile_window_with_static_distribution
                         number<NDimY>{});

                     constexpr index_t d =
-                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                        Traits::PackedSize;

-                    vec_value.template get_as<DataType>()(j) =
+                    vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
                         dstr_tensor.get_thread_buffer().template at<d>();
                 });

@@ -825,7 +833,7 @@ struct tile_window_with_static_distribution
                 // read from distributed tensor
                 vector_t vec_value;
-                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                     constexpr auto idx_ys = generate_tuple(
                         [&](auto jj) {
                             return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)

@@ -834,9 +842,10 @@ struct tile_window_with_static_distribution
                         number<NDimY>{});

                     constexpr index_t d =
-                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                        Traits::PackedSize;

-                    vec_value.template get_as<DataType>()(j) =
+                    vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
                         dstr_tensor.get_thread_buffer().template at<d>();
                 });
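The loop restructuring above has two coupled parts: vector_t shrinks to ScalarPerVector / PackedSize storage elements, and the per-scalar copy loops advance j by Traits::PackedSize while indexing storage with j / Traits::PackedSize, so each iteration moves one packed byte rather than one logical scalar. A standalone sketch of that index transformation (illustrative names, not the real tile_window code):

#include <array>
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int scalar_per_vector = 8; // logical int4 elements per access
    constexpr int packed_size       = 2; // int4 values per byte
    // a packed source "vector": scalar_per_vector / packed_size bytes
    std::array<uint8_t, scalar_per_vector / packed_size> vec_value{
        {0x10, 0x32, 0x54, 0x76}};
    std::array<uint8_t, scalar_per_vector / packed_size> thread_buf{};

    // stride by packed_size over logical scalars; copy one byte per step
    for(int j = 0; j < scalar_per_vector; j += packed_size)
    {
        thread_buf[j / packed_size] = vec_value[j / packed_size];
    }

    std::printf("copied %zu packed bytes\n", thread_buf.size());
}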
include/ck_tile/core/tensor/tile_window_linear.hpp View file @ f0bbc5db

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core/arch/arch.hpp"

@@ -151,11 +151,13 @@ struct tile_window_linear
     }

 public:
+    static constexpr index_t PackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
+
     static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
     static constexpr index_t ScalarPerVector =
         get_vector_dim_y_scalar_per_vector().template at<1>();

-    using vector_t = thread_buffer<DataType, ScalarPerVector>;
+    using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;

 private:
     static constexpr auto scalars_per_access_ = [] {

@@ -498,17 +500,18 @@ struct tile_window_linear
             // data index [y0, y1, ...]
             constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);

             // write into distributed tensor
-            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+            static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
                 constexpr auto idx_ys = generate_tuple(
                     [&](auto jj) {
                         return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                     },
                     number<NDimY>{});

-                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                constexpr index_t d =
+                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                    traits::PackedSize;
                 dst_tensor.get_thread_buffer().template at<d>() =
-                    vec_value.template get_as<DataType>()[j];
+                    vec_value.template get_as<DataType>()[j / traits::PackedSize];
             });
 #else
             constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);

@@ -556,17 +559,18 @@ struct tile_window_linear
             // data index [y0, y1, ...]
             constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);

             // write into distributed tensor
-            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+            static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
                 constexpr auto idx_ys = generate_tuple(
                     [&](auto jj) {
                         return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                     },
                     number<NDimY>{});

-                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                constexpr index_t d =
+                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                    traits::PackedSize;
                 dst_tensor.get_thread_buffer().template at<d>() =
-                    vec_value.template get_as<DataType>()[j];
+                    vec_value.template get_as<DataType>()[j / traits::PackedSize];
             });
 #else
             constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);

@@ -595,8 +599,9 @@ struct tile_window_linear
         using SFC_Ys = typename traits::SFC_Ys;
         static constexpr index_t YElementSize =
             TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
-        static_assert(YElementSize % traits::ScalarPerVector == 0);
-        using vectorized_tbuf = array<vector_t, YElementSize / traits::ScalarPerVector>;
+        static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0);
+        using vectorized_tbuf =
+            array<vector_t, YElementSize / (traits::PackedSize * traits::ScalarPerVector)>;

         constexpr auto tile_dstr = TileDstr{};

@@ -620,7 +625,9 @@ struct tile_window_linear
             // data index [y0, y1, ...]
             constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
-            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
+            constexpr index_t d =
+                tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
+                traits::PackedSize;
             static_assert(d % traits::ScalarPerVector == 0);

             get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(

@@ -804,16 +811,17 @@ struct tile_window_linear
             // read from distributed tensor
             vector_t vec_value;
-            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+            static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
                 constexpr auto idx_ys = generate_tuple(
                     [&](auto jj) {
                         return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                     },
                     number<NDimY>{});
-                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-                vec_value.template get_as<DataType>()(j) =
+                constexpr index_t d =
+                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                    traits::PackedSize;
+                vec_value.template get_as<DataType>()(j / traits::PackedSize) =
                     dstr_tensor.get_thread_buffer().template at<d>();
             });

@@ -852,14 +860,15 @@ struct tile_window_linear
             // read from distributed tensor
             vector_t vec_value;
-            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+            static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
                 constexpr auto idx_ys = generate_tuple(
                     [&](auto jj) {
                         return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                     },
                     number<NDimY>{});
-                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-                vec_value.template get_as<DataType>()(j) =
+                constexpr index_t d =
+                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                    traits::PackedSize;
+                vec_value.template get_as<DataType>()(j / traits::PackedSize) =
                     dstr_tensor.get_thread_buffer().template at<d>();
             });

@@ -897,16 +906,17 @@ struct tile_window_linear
             // read from distributed tensor
             vector_t vec_value;
-            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+            static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
                 constexpr auto idx_ys = generate_tuple(
                     [&](auto jj) {
                         return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                     },
                     number<NDimY>{});
-                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-                vec_value.template get_as<DataType>()(j) =
+                constexpr index_t d =
+                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                    traits::PackedSize;
+                vec_value.template get_as<DataType>()(j / traits::PackedSize) =
                     dstr_tensor.get_thread_buffer().template at<d>();
             });

@@ -948,16 +958,17 @@ struct tile_window_linear
             // read from distributed tensor
             vector_t vec_value;
-            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+            static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
                 constexpr auto idx_ys = generate_tuple(
                     [&](auto jj) {
                         return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                     },
                     number<NDimY>{});
-                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-                vec_value.template get_as<DataType>()(j) =
+                constexpr index_t d =
+                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                    traits::PackedSize;
+                vec_value.template get_as<DataType>()(j / traits::PackedSize) =
                     dstr_tensor.get_thread_buffer().template at<d>();
             });
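The tightened static_assert above also changes the sizing arithmetic for vectorized_tbuf. A worked instance of that arithmetic as compile-time checks, assuming illustrative values (ScalarPerVector = 8 logical int4 values, PackedSize = 2, YElementSize = 32 logical elements):

#include <cstdint>

constexpr int YElementSize    = 32; // logical elements in the Y-to-D space
constexpr int ScalarPerVector = 8;  // logical scalars per vector access
constexpr int PackedSize      = 2;  // int4 values per byte

// the tile must divide evenly into packed vectors...
static_assert(YElementSize % (PackedSize * ScalarPerVector) == 0);
// ...yielding 32 / (2 * 8) = 2 vectors of 8 / 2 = 4 storage bytes each
static_assert(YElementSize / (PackedSize * ScalarPerVector) == 2);
static_assert(ScalarPerVector / PackedSize == 4);

int main() { return 0; }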
include/ck_tile/host/check_err.hpp View file @ f0bbc5db

@@ -29,11 +29,12 @@ double get_relative_threshold(const int number_of_accumulations = 1)
     using I8  = int8_t;
     using I32 = int32_t;

-    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
-                  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
+    static_assert(
+        is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
+        "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
     double compute_error = 0;
-    if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }

@@ -42,11 +43,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
         compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
     }

-    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
                   "Warning: Unhandled OutDataType for setting up the relative threshold!");
     double output_error = 0;
-    if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }

@@ -56,11 +57,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
     }

     double midway_error = std::max(compute_error, output_error);

-    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
                   "Warning: Unhandled AccDataType for setting up the relative threshold!");
     double acc_error = 0;
-    if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }

@@ -82,12 +83,13 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
     using I8  = int8_t;
     using I32 = int32_t;

-    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
-                  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
+    static_assert(
+        is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
+        "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
     auto expo            = std::log2(std::abs(max_possible_num));
     double compute_error = 0;
-    if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }

@@ -96,11 +98,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
         compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
     }

-    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
                   "Warning: Unhandled OutDataType for setting up the absolute threshold!");
     double output_error = 0;
-    if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }

@@ -110,11 +112,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
     }

     double midway_error = std::max(compute_error, output_error);

-    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
                   "Warning: Unhandled AccDataType for setting up the absolute threshold!");
     double acc_error = 0;
-    if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }
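These hunks simply fold pk_int4_t into the integer branch of the threshold logic: integer arithmetic is exact, so the tolerance is 0, while floating types get half a ULP, pow(2, -mant) * 0.5, scaled by the exponent in the absolute case. A standalone illustration of the resulting values, using the usual mantissa widths (the real code reads them from numeric_traits):

#include <cmath>
#include <cstdio>

int main()
{
    const int mant_f16  = 10; // fp16 mantissa bits
    const int mant_bf16 = 7;  // bf16 mantissa bits

    // relative thresholds: half a unit in the last place
    std::printf("f16 relative threshold:  %g\n", std::pow(2.0, -mant_f16) * 0.5);
    std::printf("bf16 relative threshold: %g\n", std::pow(2.0, -mant_bf16) * 0.5);
    // integer and packed-int4 results are exact, hence zero tolerance
    std::printf("int / pk_int4 threshold: %g\n", 0.0);
}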
include/ck_tile/host/fill.hpp View file @ f0bbc5db

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -282,7 +282,14 @@ struct FillMonotonicSeq
     {
         std::generate(first, last, [=, n = init_value_]() mutable {
             auto tmp = n;
-            n += step_;
+            if constexpr(std::is_same_v<decltype(tmp), pk_int4_t>)
+            {
+                n.data += step_.data;
+            }
+            else
+            {
+                n += step_;
+            }
             return tmp;
         });
     }
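The special case exists because the packed type is a raw byte wrapper without operator+=, so the monotonic fill advances its .data field directly. A minimal mock (not the real ck_tile::pk_int4_t) demonstrating the same dispatch:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <type_traits>
#include <vector>

struct mock_pk_int4
{
    uint8_t data; // two 4-bit lanes packed in one byte; no arithmetic operators
};

template <typename T>
void fill_monotonic(std::vector<T>& v, T init, T step)
{
    std::generate(v.begin(), v.end(), [=, n = init]() mutable {
        auto tmp = n;
        if constexpr(std::is_same_v<decltype(tmp), mock_pk_int4>)
        {
            n.data += step.data; // step the raw packed byte
        }
        else
        {
            n += step; // ordinary arithmetic types
        }
        return tmp;
    });
}

int main()
{
    std::vector<mock_pk_int4> packed(4);
    fill_monotonic(packed, mock_pk_int4{0x01}, mock_pk_int4{0x11});
    for(auto p : packed)
        std::printf("0x%02x ", p.data); // 0x01 0x12 0x23 0x34
    std::printf("\n");
}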