Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bf8e6de7
Commit
bf8e6de7
authored
Sep 03, 2024
by
carlushuang
Browse files
support argmax reduce in test
parent
ee956e8e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
291 additions
and
17 deletions
+291
-17
include/ck_tile/core/arch/utility.hpp
include/ck_tile/core/arch/utility.hpp
+70
-9
include/ck_tile/core/numeric/math.hpp
include/ck_tile/core/numeric/math.hpp
+17
-0
test/tile_reduce/tile_reduce.cpp
test/tile_reduce/tile_reduce.cpp
+204
-8
No files found.
include/ck_tile/core/arch/utility.hpp
View file @
bf8e6de7
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include <stdint.h>
#include <stdint.h>
...
@@ -33,14 +34,44 @@ CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
...
@@ -33,14 +34,44 @@ CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
#if 0
#if 0
return __shfl_up(v_local, lane_delta);
return __shfl_up(v_local, lane_delta);
#elif
1
#elif
1
static_assert
(
sizeof
(
T
)
==
sizeof
(
int32_t
),
"wrong!"
);
const
uint32_t
wrap_around_lane_delta
=
warpSize
-
lane_delta
;
const
uint32_t
wrap_around_lane_delta
=
warpSize
-
lane_delta
;
const
int32_t
v_remote_tmp
=
__builtin_amdgcn_ds_bpermute
(
if
constexpr
(
sizeof
(
int32_t
)
>
sizeof
(
T
))
(
__lane_id
()
<<
2
)
+
(
wrap_around_lane_delta
<<
2
),
bit_cast
<
int32_t
>
(
v_local
));
{
union
packet
{
int32_t
x
;
T
v
;
};
packet
p
;
p
.
v
=
v_local
;
packet
p_remote
;
p_remote
.
x
=
__builtin_amdgcn_ds_bpermute
(
(
__lane_id
()
<<
2
)
+
(
wrap_around_lane_delta
<<
2
),
bit_cast
<
int32_t
>
(
p
));
return
p_remote
.
v
;
}
else
if
constexpr
(
sizeof
(
int32_t
)
==
sizeof
(
T
))
{
const
int32_t
v_remote_tmp
=
__builtin_amdgcn_ds_bpermute
(
(
__lane_id
()
<<
2
)
+
(
wrap_around_lane_delta
<<
2
),
bit_cast
<
int32_t
>
(
v_local
));
return
bit_cast
<
T
>
(
v_remote_tmp
);
return
bit_cast
<
T
>
(
v_remote_tmp
);
}
else
{
static_assert
(
sizeof
(
T
)
%
sizeof
(
int32_t
)
==
0
,
"wrong!"
);
constexpr
index_t
elm
=
sizeof
(
T
)
/
sizeof
(
int32_t
);
using
vector_type
=
thread_buffer
<
int32_t
,
elm
>
;
auto
vs
=
bit_cast
<
vector_type
>
(
v_local
);
auto
vs_remote
=
vector_type
{};
static_for
<
0
,
elm
,
1
>
{}([
&
](
auto
i_e
)
{
int32_t
tmp
=
__builtin_amdgcn_ds_bpermute
(
(
__lane_id
()
<<
2
)
+
(
wrap_around_lane_delta
<<
2
),
bit_cast
<
int32_t
>
(
vs
[
i_e
]));
vs_remote
(
i_e
)
=
tmp
;
});
return
bit_cast
<
T
>
(
vs_remote
);
}
#endif
#endif
}
}
...
@@ -50,12 +81,42 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
...
@@ -50,12 +81,42 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#if 0
#if 0
return __shfl_down(v_local, lane_delta);
return __shfl_down(v_local, lane_delta);
#elif
1
#elif
1
static_assert
(
sizeof
(
T
)
==
sizeof
(
int32_t
),
"wrong!"
);
if
constexpr
(
sizeof
(
int32_t
)
>
sizeof
(
T
))
{
union
packet
{
int32_t
x
;
T
v
;
};
packet
p
;
p
.
v
=
v_local
;
packet
p_remote
;
p_remote
.
x
=
__builtin_amdgcn_ds_bpermute
((
__lane_id
()
<<
2
)
+
(
lane_delta
<<
2
),
bit_cast
<
int32_t
>
(
p
));
const
int32_t
v_remote_tmp
=
__builtin_amdgcn_ds_bpermute
(
return
p_remote
.
v
;
(
__lane_id
()
<<
2
)
+
(
lane_delta
<<
2
),
bit_cast
<
int32_t
>
(
v_local
));
}
else
if
constexpr
(
sizeof
(
int32_t
)
==
sizeof
(
T
))
{
const
int32_t
v_remote_tmp
=
__builtin_amdgcn_ds_bpermute
(
(
__lane_id
()
<<
2
)
+
(
lane_delta
<<
2
),
bit_cast
<
int32_t
>
(
v_local
));
return
bit_cast
<
T
>
(
v_remote_tmp
);
return
bit_cast
<
T
>
(
v_remote_tmp
);
}
else
{
static_assert
(
sizeof
(
T
)
%
sizeof
(
int32_t
)
==
0
,
"wrong!"
);
constexpr
index_t
elm
=
sizeof
(
T
)
/
sizeof
(
int32_t
);
using
vector_type
=
thread_buffer
<
int32_t
,
elm
>
;
auto
vs
=
bit_cast
<
vector_type
>
(
v_local
);
auto
vs_remote
=
vector_type
{};
static_for
<
0
,
elm
,
1
>
{}([
&
](
auto
i_e
)
{
int32_t
tmp
=
__builtin_amdgcn_ds_bpermute
((
__lane_id
()
<<
2
)
+
(
lane_delta
<<
2
),
bit_cast
<
int32_t
>
(
vs
[
i_e
]));
vs_remote
(
i_e
)
=
tmp
;
});
return
bit_cast
<
T
>
(
vs_remote
);
}
#endif
#endif
}
}
...
...
include/ck_tile/core/numeric/math.hpp
View file @
bf8e6de7
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include <type_traits>
#include <type_traits>
#include <stdint.h>
#include <stdint.h>
#include <cmath>
#include <cmath>
...
@@ -165,12 +166,28 @@ CK_TILE_HOST constexpr T max(T x, T y)
...
@@ -165,12 +166,28 @@ CK_TILE_HOST constexpr T max(T x, T y)
return
x
>
y
?
x
:
y
;
return
x
>
y
?
x
:
y
;
}
}
template
<
>
CK_TILE_HOST
constexpr
fp16_t
max
(
fp16_t
x
,
fp16_t
y
)
{
float
x_
=
static_cast
<
float
>
(
x
);
float
y_
=
static_cast
<
float
>
(
y
);
return
x_
>
y_
?
x
:
y
;
}
template
<
typename
T
>
template
<
typename
T
>
CK_TILE_DEVICE
constexpr
T
max
(
T
x
,
T
y
)
CK_TILE_DEVICE
constexpr
T
max
(
T
x
,
T
y
)
{
{
return
x
>
y
?
x
:
y
;
return
x
>
y
?
x
:
y
;
}
}
template
<
>
CK_TILE_DEVICE
fp16_t
max
(
fp16_t
x
,
fp16_t
y
)
{
fp16_t
rtn
;
asm
volatile
(
"v_max_f16 %0, %1, %2"
:
"=v"
(
rtn
)
:
"v"
(
x
),
"v"
(
y
));
return
rtn
;
}
template
<
>
template
<
>
CK_TILE_DEVICE
constexpr
float
max
(
float
x
,
float
y
)
CK_TILE_DEVICE
constexpr
float
max
(
float
x
,
float
y
)
{
{
...
...
test/tile_reduce/tile_reduce.cpp
View file @
bf8e6de7
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/ops/reduce.hpp"
#ifndef TEST_TILE_REDUCE_VERBOSE
#ifndef TEST_TILE_REDUCE_VERBOSE
#define TEST_TILE_REDUCE_VERBOSE
1
#define TEST_TILE_REDUCE_VERBOSE
0
#endif
#endif
#define HIP_CALL(call) \
#define HIP_CALL(call) \
...
@@ -88,17 +88,16 @@ __global__ void reduce_row(DataType* p_src, DataType* p_dst)
...
@@ -88,17 +88,16 @@ __global__ void reduce_row(DataType* p_src, DataType* p_dst)
auto
data
=
load_tile
(
src_tile
);
auto
data
=
load_tile
(
src_tile
);
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
ck_tile
::
max
(
e0
,
e1
);
};
// Note: the return type will fill the replicate dim
// Note: the return type will fill the replicate dim
// usually is 2d. This is for the next block_tile_reduce_sync()
// usually is 2d. The hlength of r is 1d.
// This is for the next block_tile_reduce_sync()
// in order to do further reduce.
// in order to do further reduce.
auto
r
=
auto
r
=
block_tile_reduce
<
DataType
>
(
data
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
DataType
>::
infinity
());
block_tile_reduce
<
DataType
>
(
data
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
DataType
>::
infinity
());
// r.foo();
// further reduce cross thread, Note Now the HLength of r is 1D
// further reduce cross thread
block_tile_reduce_sync
(
r
,
f_max
,
bool_constant
<
false
>
{});
block_tile_reduce_sync
(
r
,
f_max
,
bool_constant
<
false
>
{});
if
(
threadIdx
.
x
%
col_lanes
==
0
)
if
(
threadIdx
.
x
%
col_lanes
==
0
)
...
@@ -109,6 +108,123 @@ __global__ void reduce_row(DataType* p_src, DataType* p_dst)
...
@@ -109,6 +108,123 @@ __global__ void reduce_row(DataType* p_src, DataType* p_dst)
}
}
}
}
template
<
int
Rows
,
int
Cols
,
typename
DataType
,
int
BytesPerIssue
=
16
>
__global__
void
reduce_row_argmax
(
DataType
*
p_src
,
DataType
*
p_dst
,
int
*
p_idx
)
{
using
namespace
ck_tile
;
// some constexpr vars
constexpr
index_t
vec
=
BytesPerIssue
/
sizeof
(
DataType
);
static_assert
(
Cols
%
vec
==
0
);
constexpr
index_t
col_lanes
=
Cols
/
vec
;
constexpr
index_t
warp_size
=
ck_tile
::
get_warp_size
();
static_assert
(
warp_size
%
col_lanes
==
0
);
constexpr
index_t
row_lanes
=
warp_size
/
col_lanes
;
constexpr
index_t
num_warps
=
BLOCK_SIZE
/
warp_size
;
static_assert
(
Rows
%
(
num_warps
*
row_lanes
)
==
0
);
constexpr
index_t
row_repeat
=
Rows
/
(
num_warps
*
row_lanes
);
auto
src_tile
=
[
&
]()
{
constexpr
auto
src_dist
=
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
row_repeat
,
num_warps
,
row_lanes
>
,
sequence
<
col_lanes
,
vec
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
auto
src_view
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_src
,
make_tuple
(
Rows
,
Cols
),
make_tuple
(
Cols
,
1
),
number
<
vec
>
{},
// alignement
number
<
1
>
{});
return
make_tile_window
(
src_view
,
make_tuple
(
number
<
Rows
>
{},
number
<
Cols
>
{}),
{
0
,
0
},
src_dist
);
}();
constexpr
auto
dst_dist
=
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
col_lanes
>
,
// -> replicate here, hence we can figure out the offset
tuple
<
sequence
<
row_repeat
,
num_warps
,
row_lanes
>
,
sequence
<
1
>
/* only 1 per row*/
>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{});
auto
dst_tile
=
[
&
]()
{
auto
dst_view
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_dst
,
make_tuple
(
Rows
,
1
),
make_tuple
(
1
,
1
),
number
<
1
>
{},
// alignement
number
<
1
>
{});
return
make_tile_window
(
dst_view
,
make_tuple
(
number
<
Rows
>
{},
number
<
1
>
{}),
{
0
,
0
},
dst_dist
);
}();
auto
idx_tile
=
[
&
]()
{
auto
idx_view
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_idx
,
make_tuple
(
Rows
,
1
),
make_tuple
(
1
,
1
),
number
<
1
>
{},
// alignement
number
<
1
>
{});
return
make_tile_window
(
idx_view
,
make_tuple
(
number
<
Rows
>
{},
number
<
1
>
{}),
{
0
,
0
},
dst_dist
);
}();
auto
data
=
load_tile
(
src_tile
);
struct
kv
{
DataType
arg
;
int
value
;
// this is col_id per row
};
auto
kv_data
=
make_static_distributed_tensor
<
kv
>
(
data
.
get_tile_distribution
());
// compute elementwise softmax
constexpr
auto
span_2d
=
decltype
(
kv_data
)
::
get_distributed_spans
();
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
kv_data
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
kv
tmp
;
tmp
.
arg
=
data
(
i_j_idx
);
tmp
.
value
=
tile_idx
.
at
(
number
<
1
>
{});
kv_data
(
i_j_idx
)
=
tmp
;
});
});
const
auto
f_arg_max
=
[](
kv
e0
,
kv
e1
)
{
return
e0
.
arg
>
e1
.
arg
?
e0
:
e1
;
};
auto
arg_max_init
=
kv
{
-
numeric
<
DataType
>::
infinity
(),
0
};
auto
r
=
block_tile_reduce
<
kv
>
(
kv_data
,
sequence
<
1
>
{},
f_arg_max
,
arg_max_init
);
// further reduce cross thread, Note Now the HLength of r is 1D
block_tile_reduce_sync
(
r
,
f_arg_max
,
bool_constant
<
false
>
{});
auto
o
=
make_static_distributed_tensor
<
DataType
>
(
dst_dist
);
auto
i
=
make_static_distributed_tensor
<
int
>
(
dst_dist
);
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
kv
tmp
=
r
(
i_j_idx
);
o
(
i_j_idx
)
=
tmp
.
arg
;
i
(
i_j_idx
)
=
tmp
.
value
;
});
});
if
(
threadIdx
.
x
%
col_lanes
==
0
)
{
store_tile
(
dst_tile
,
o
);
store_tile
(
idx_tile
,
i
);
}
}
template
<
int
Rows
,
int
Cols
,
typename
DataType
,
int
BytesPerIssue
=
16
>
template
<
int
Rows
,
int
Cols
,
typename
DataType
,
int
BytesPerIssue
=
16
>
bool
test_tile_reduce
()
bool
test_tile_reduce
()
{
{
...
@@ -154,12 +270,12 @@ bool test_tile_reduce()
...
@@ -154,12 +270,12 @@ bool test_tile_reduce()
}
}
{
{
uint32_t
ref
=
ck_tile
::
bit_cast
<
uint32_t
>
(
row_max
);
uint32_t
ref
=
ck_tile
::
bit_cast
<
uint32_t
>
(
row_max
);
uint32_t
out
=
ck_tile
::
bit_cast
<
uint32_t
>
(
dst
[
i_r
]);
uint32_t
out
=
ck_tile
::
bit_cast
<
uint32_t
>
(
ck_tile
::
type_convert
<
float
>
(
dst
[
i_r
])
)
;
if
(
ref
!=
out
)
if
(
ref
!=
out
)
err_cnt
++
;
err_cnt
++
;
}
}
#if TEST_TILE_REDUCE_VERBOSE
#if TEST_TILE_REDUCE_VERBOSE
printf
(
" -> %.3f (%.3f)
\n
"
,
dst
[
i_r
],
row_max
);
printf
(
" -> %.3f (%.3f)
\n
"
,
ck_tile
::
type_convert
<
float
>
(
dst
[
i_r
]
)
,
row_max
);
#endif
#endif
}
}
#if TEST_TILE_REDUCE_VERBOSE
#if TEST_TILE_REDUCE_VERBOSE
...
@@ -171,11 +287,91 @@ bool test_tile_reduce()
...
@@ -171,11 +287,91 @@ bool test_tile_reduce()
return
err_cnt
==
0
?
true
:
false
;
return
err_cnt
==
0
?
true
:
false
;
}
}
template
<
int
Rows
,
int
Cols
,
typename
DataType
,
int
BytesPerIssue
=
16
>
bool
test_tile_reduce_argmax
()
{
std
::
srand
(
std
::
time
(
nullptr
));
DataType
*
src
=
reinterpret_cast
<
DataType
*>
(
malloc
(
Rows
*
Cols
*
sizeof
(
DataType
)));
DataType
*
dst
=
reinterpret_cast
<
DataType
*>
(
malloc
(
Rows
*
sizeof
(
DataType
)));
int
*
idx
=
reinterpret_cast
<
int
*>
(
malloc
(
Rows
*
sizeof
(
int
)));
// const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
for
(
auto
i
=
0
;
i
<
Rows
*
Cols
;
i
++
)
{
float
v
=
static_cast
<
float
>
(
std
::
rand
()
%
2000
-
1000
)
/
1000.
f
;
src
[
i
]
=
ck_tile
::
type_convert
<
DataType
>
(
v
);
}
void
*
dev_src
;
void
*
dev_dst
;
void
*
dev_idx
;
HIP_CALL
(
hipMalloc
(
&
dev_src
,
Rows
*
Cols
*
sizeof
(
DataType
)));
HIP_CALL
(
hipMalloc
(
&
dev_dst
,
Rows
*
sizeof
(
DataType
)));
HIP_CALL
(
hipMalloc
(
&
dev_idx
,
Rows
*
sizeof
(
int
)));
HIP_CALL
(
hipMemcpy
(
dev_src
,
src
,
Rows
*
Cols
*
sizeof
(
DataType
),
hipMemcpyHostToDevice
));
constexpr
int
bdim
=
BLOCK_SIZE
;
int
gdim
=
1
;
reduce_row_argmax
<
Rows
,
Cols
,
DataType
,
BytesPerIssue
>
<<<
gdim
,
bdim
>>>
(
reinterpret_cast
<
DataType
*>
(
dev_src
),
reinterpret_cast
<
DataType
*>
(
dev_dst
),
reinterpret_cast
<
int
*>
(
dev_idx
));
HIP_CALL
(
hipMemcpy
(
dst
,
dev_dst
,
Rows
*
sizeof
(
DataType
),
hipMemcpyDeviceToHost
));
HIP_CALL
(
hipMemcpy
(
idx
,
dev_idx
,
Rows
*
sizeof
(
int
),
hipMemcpyDeviceToHost
));
int
err_cnt
=
0
;
for
(
int
i_r
=
0
;
i_r
<
Rows
;
i_r
++
)
{
auto
row_max
=
-
ck_tile
::
numeric
<
float
>::
infinity
();
int
row_idx
=
-
1
;
for
(
int
i_c
=
0
;
i_c
<
Cols
;
i_c
++
)
{
int
idx_
=
i_r
*
Cols
+
i_c
;
float
v
=
ck_tile
::
type_convert
<
float
>
(
src
[
idx_
]);
row_max
=
row_max
>
v
?
row_max
:
v
;
row_idx
=
row_max
>
v
?
row_idx
:
i_c
;
#if TEST_TILE_REDUCE_VERBOSE
printf
(
"%.3f "
,
v
);
#endif
}
{
uint32_t
ref
=
ck_tile
::
bit_cast
<
uint32_t
>
(
row_max
);
uint32_t
out
=
ck_tile
::
bit_cast
<
uint32_t
>
(
ck_tile
::
type_convert
<
float
>
(
dst
[
i_r
]));
if
(
ref
!=
out
)
err_cnt
++
;
if
(
idx
[
i_r
]
!=
row_idx
)
err_cnt
++
;
}
#if TEST_TILE_REDUCE_VERBOSE
printf
(
" -> %.3f,%d (%.3f,%d)
\n
"
,
ck_tile
::
type_convert
<
float
>
(
dst
[
i_r
]),
idx
[
i_r
],
row_max
,
row_idx
);
#endif
}
#if TEST_TILE_REDUCE_VERBOSE
printf
(
"
\n
"
);
#endif
free
(
src
);
free
(
dst
);
free
(
idx
);
return
err_cnt
==
0
?
true
:
false
;
}
int
main
()
int
main
()
{
{
bool
r
=
true
;
bool
r
=
true
;
r
&=
test_tile_reduce
<
32
,
64
,
float
>
();
r
&=
test_tile_reduce
<
32
,
64
,
float
>
();
r
&=
test_tile_reduce
<
32
,
16
,
float
,
4
>
();
r
&=
test_tile_reduce
<
32
,
16
,
float
,
4
>
();
r
&=
test_tile_reduce
<
32
,
16
,
ck_tile
::
fp16_t
,
4
>
();
r
&=
test_tile_reduce_argmax
<
32
,
16
,
float
,
4
>
();
return
r
?
0
:
-
1
;
return
r
?
0
:
-
1
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment