gaoqiong / composable_kernel_ROCM · Commits · 3c5717df

Commit 3c5717df (unverified), authored Feb 10, 2025 by Illia Silin; committed via GitHub on Feb 10, 2025

    Merge branch 'develop' into gemm_elementwise_gemm

Parents: 171b9030, d9f1ead3
Changes: 454 files (the commit is too large to render fully; GitLab displays only 454 of 454+ files to preserve performance). The 20 changed files shown below carry 965 additions and 189 deletions (+965 / -189):
include/ck_tile/core/utility/type_traits.hpp                           +18   -0
include/ck_tile/core/utility/unary_element_function.hpp                +9    -7
include/ck_tile/host.hpp                                               +4    -1
include/ck_tile/host/arg_parser.hpp                                    +44   -2
include/ck_tile/host/check_err.hpp                                     +119  -3
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp     +27   -57
include/ck_tile/host/device_memory.hpp                                 +35   -0
include/ck_tile/host/fill.hpp                                          +107  -6
include/ck_tile/host/host_tensor.hpp                                   +143  -19
include/ck_tile/host/joinable_thread.hpp                               +27   -0
include/ck_tile/host/reference/reference_batched_transpose.hpp         +59   -0
include/ck_tile/host/reference/reference_fused_moe.hpp                 +205  -0
include/ck_tile/host/reference/reference_gemm.hpp                      +41   -68
include/ck_tile/host/reference/reference_moe_sorting.hpp               +24   -5
include/ck_tile/host/reference/reference_permute.hpp                   +21   -2
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp             +30   -4
include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp    +1    -1
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp                          +1    -1
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp     +30  -7
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp   +20  -6
include/ck_tile/core/utility/type_traits.hpp  (view file @ 3c5717df)

@@ -109,4 +109,22 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
 #pragma clang diagnostic pop
 }
+
+template <typename CompareTo, typename... Rest>
+struct is_any_of : std::false_type
+{
+};
+
+template <typename CompareTo, typename FirstType>
+struct is_any_of<CompareTo, FirstType> : std::is_same<CompareTo, FirstType>
+{
+};
+
+template <typename CompareTo, typename FirstType, typename... Rest>
+struct is_any_of<CompareTo, FirstType, Rest...>
+    : std::integral_constant<bool,
+                             std::is_same<CompareTo, FirstType>::value ||
+                                 is_any_of<CompareTo, Rest...>::value>
+{
+};
 } // namespace ck_tile
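For context, the new trait folds a pack of candidate types into a single boolean by recursing over the pack. A minimal usage sketch (the asserts below are illustrative, not part of this commit):

    // Sketch: accept only types from a known candidate list.
    static_assert(ck_tile::is_any_of<float, ck_tile::half_t, ck_tile::bf16_t, float>::value,
                  "float is in the candidate list");
    static_assert(!ck_tile::is_any_of<double, ck_tile::half_t, ck_tile::bf16_t, float>::value,
                  "double is not in the candidate list");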
include/ck_tile/core/utility/unary_element_function.hpp  (view file @ 3c5717df)

@@ -51,16 +51,18 @@ struct composes<F>
 template <typename... Ts>
 __host__ __device__ composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
 
-template <typename To>
+template <typename SaturateType>
 struct saturates
 {
-    template <typename From>
-    CK_TILE_HOST_DEVICE constexpr auto operator()(const From& from) const
-        -> std::enable_if_t<std::is_arithmetic_v<From>, From>
+    // NOTE: this function does not return SaturateType value
+    // it is user's responsiblity to do further cast or not
+    template <typename AccType>
+    CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const
+        -> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
     {
-        return clamp(from,
-                     type_convert<From>(numeric<To>::lowest()),
-                     type_convert<From>(numeric<To>::max()));
+        return clamp(a_,
+                     type_convert<AccType>(numeric<SaturateType>::lowest()),
+                     type_convert<AccType>(numeric<SaturateType>::max()));
     }
 };
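As the new NOTE spells out, the renamed functor clamps an arithmetic input into the representable range of SaturateType but returns the input type, leaving the final cast to the caller. A hedged usage sketch:

    // Sketch: clamp a float accumulator into fp8 range before converting.
    // The returned value is still float; the caller performs the cast.
    float acc    = 1.0e6f;
    float capped = ck_tile::saturates<ck_tile::fp8_t>{}(acc); // clamped to numeric<fp8_t>::max()
    ck_tile::fp8_t out = ck_tile::type_convert<ck_tile::fp8_t>(capped);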
include/ck_tile/host.hpp  (view file @ 3c5717df)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -11,6 +11,7 @@
 #include "ck_tile/host/fill.hpp"
 #include "ck_tile/host/hip_check_error.hpp"
 #include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/host/joinable_thread.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
 #include "ck_tile/host/ranges.hpp"
 #include "ck_tile/host/reference/reference_batched_dropout.hpp"
@@ -19,7 +20,9 @@
 #include "ck_tile/host/reference/reference_batched_masking.hpp"
 #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
 #include "ck_tile/host/reference/reference_batched_softmax.hpp"
+#include "ck_tile/host/reference/reference_batched_transpose.hpp"
 #include "ck_tile/host/reference/reference_elementwise.hpp"
+#include "ck_tile/host/reference/reference_fused_moe.hpp"
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
 #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
include/ck_tile/host/arg_parser.hpp  (view file @ 3c5717df)

@@ -15,11 +15,14 @@
 namespace ck_tile {
 
 /*
- * a host side utility, arg parser for
+ * a host side utility, arg parser for, either
+ * -[key0] = [value0, value1, value2]
+ * or
  * -[key0]=[value0] -[key1]=[value1] ...
 */
 class ArgParser
 {
     public:
     class Arg
     {
@@ -187,6 +190,45 @@ class ArgParser
         return value;
     }
 
+    std::vector<std::string> get_string_vec(const std::string& name,
+                                            const std::string& delimiter = ",") const
+    {
+        if(get_str(name).empty())
+        {
+            return {};
+        }
+        std::string s = get_str(name);
+        std::vector<std::string> tokens;
+        size_t pos = 0;
+        std::string token;
+        while((pos = s.find(delimiter)) != std::string::npos)
+        {
+            token = s.substr(0, pos);
+            tokens.push_back(token);
+            s.erase(0, pos + delimiter.length());
+        }
+        tokens.push_back(s);
+        return tokens;
+    }
+
+    std::vector<int> get_int_vec(const std::string& name, const std::string& delimiter = ",") const
+    {
+        if(get_str(name).empty())
+        {
+            return {};
+        }
+        const std::vector<std::string> args = get_string_vec(name, delimiter);
+        std::vector<int> tokens;
+        tokens.reserve(static_cast<int>(args.size()));
+        for(const std::string& token : args)
+        {
+            int value = atoi(token.c_str());
+            tokens.push_back(value);
+        }
+        return tokens;
+    }
+
     private:
     std::unordered_map<std::string, Arg> input_map;
     std::vector<std::string> keys;
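A sketch of how the new vector getters might be driven from the command line. The population step is assumed from the surrounding class (it is not visible in this hunk):

    // Hypothetical driver: "-dims=4,8,16" becomes {4, 8, 16}.
    ck_tile::ArgParser parser;
    // assumed: parser was populated from argc/argv elsewhere
    std::vector<int> dims         = parser.get_int_vec("dims");         // {4, 8, 16}
    std::vector<std::string> tags = parser.get_string_vec("tags", ";"); // custom delimiter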
include/ck_tile/host/check_err.hpp  (view file @ 3c5717df)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -18,6 +18,114 @@
 namespace ck_tile {
 
+template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
+double get_relative_threshold(const int number_of_accumulations = 1)
+{
+    using F8   = ck_tile::fp8_t;
+    using BF8  = ck_tile::bf8_t;
+    using F16  = ck_tile::half_t;
+    using BF16 = ck_tile::bf16_t;
+    using F32  = float;
+    using I8   = int8_t;
+    using I32  = int32_t;
+
+    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+                  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
+    double compute_error = 0;
+    if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
+    {
+        return 0;
+    }
+    else
+    {
+        compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
+    }
+
+    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+                  "Warning: Unhandled OutDataType for setting up the relative threshold!");
+    double output_error = 0;
+    if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
+    {
+        return 0;
+    }
+    else
+    {
+        output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
+    }
+    double midway_error = std::max(compute_error, output_error);
+
+    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+                  "Warning: Unhandled AccDataType for setting up the relative threshold!");
+    double acc_error = 0;
+    if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
+    {
+        return 0;
+    }
+    else
+    {
+        acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
+    }
+    return std::max(acc_error, midway_error);
+}
+
+template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
+double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
+{
+    using F8   = ck_tile::fp8_t;
+    using BF8  = ck_tile::bf8_t;
+    using F16  = ck_tile::half_t;
+    using BF16 = ck_tile::bf16_t;
+    using F32  = float;
+    using I8   = int8_t;
+    using I32  = int32_t;
+
+    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+                  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
+    auto expo            = std::log2(std::abs(max_possible_num));
+    double compute_error = 0;
+    if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
+    {
+        return 0;
+    }
+    else
+    {
+        compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
+    }
+
+    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+                  "Warning: Unhandled OutDataType for setting up the absolute threshold!");
+    double output_error = 0;
+    if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
+    {
+        return 0;
+    }
+    else
+    {
+        output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
+    }
+    double midway_error = std::max(compute_error, output_error);
+
+    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+                  "Warning: Unhandled AccDataType for setting up the absolute threshold!");
+    double acc_error = 0;
+    if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
+    {
+        return 0;
+    }
+    else
+    {
+        acc_error =
+            std::pow(2, expo - numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
+    }
+    return std::max(acc_error, midway_error);
+}
+
 template <typename T>
 std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
 {

@@ -337,7 +445,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
     }
     if(!res)
     {
-        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        const float error_percent =
+            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
+        std::cerr << "max err: " << max_err;
+        std::cerr << ", number of errors: " << err_count;
+        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
     }
     return res;
 }

@@ -391,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
     }
     if(!res)
     {
-        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        const float error_percent =
+            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
+        std::cerr << "max err: " << max_err;
+        std::cerr << ", number of errors: " << err_count;
+        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
    }
     return res;
 }
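The thresholds scale the unit roundoff 2^(-mant) by 0.5 and, for the accumulator, by the number of accumulations; integer types short-circuit to an exact (zero) tolerance. A sketch of how a GEMM test might drive them, assuming K accumulations per output element (the |C| bound and the check_err signature below are assumptions of this sketch, not from the commit):

    // Sketch: tolerances for an fp16 GEMM with fp32 accumulation.
    using ADataType   = ck_tile::half_t;
    using AccDataType = float;
    const int K = 512; // accumulations per C element
    double rtol = ck_tile::get_relative_threshold<ADataType, ADataType, AccDataType>(K);
    double atol = ck_tile::get_absolute_threshold<ADataType, ADataType, AccDataType>(
        /*max_possible_num=*/2.0 * K, K); // crude bound on |C|, assumed here
    // then e.g.: check_err(out, ref, "gemm mismatch", rtol, atol);  (signature assumed)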
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp  (view file @ 3c5717df)

@@ -14,57 +14,41 @@ namespace detail {
 template <typename OldLayout>
 CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
 {
-    if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
-                 std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
-                 std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
+    using namespace ck_tile::tensor_layout::convolution;
+
+    if constexpr(is_any_of<OldLayout, GNCW, GKCX, GNKW>::value)
     {
         return {0, 1, 2, 3};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
+    else if constexpr(is_any_of<OldLayout, GNCHW, GKCYX, GNKHW>::value)
     {
         return {0, 1, 2, 3, 4};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
+    else if constexpr(is_any_of<OldLayout, GNCDHW, GKCZYX, GNKDHW>::value)
     {
         return {0, 1, 2, 3, 4, 5};
     }
-    if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
-                 std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
-                 std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
+    if constexpr(is_any_of<OldLayout, GNWC, GKXC, GNWK>::value)
     {
         return {0, 1, 3, 2};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
+    else if constexpr(is_any_of<OldLayout, GNHWC, GKYXC, GNHWK>::value)
     {
         return {0, 1, 4, 2, 3};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
+    else if constexpr(is_any_of<OldLayout, GNDHWC, GKZYXC, GNDHWK>::value)
     {
         return {0, 1, 5, 2, 3, 4};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
+    else if constexpr(is_any_of<OldLayout, NWGC, KXGC, NWGK>::value)
     {
         return {2, 0, 3, 1};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
+    else if constexpr(is_any_of<OldLayout, NHWGC, KYXGC, NHWGK>::value)
     {
         return {3, 0, 4, 1, 2};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
-                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
+    else if constexpr(is_any_of<OldLayout, NDHWGC, KZYXGC, NDHWGK>::value)
     {
         return {4, 0, 5, 1, 2, 3};
     }

@@ -83,11 +67,11 @@ template <typename InLayout>
 CK_TILE_HOST HostTensorDescriptor
 make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
 {
+    using namespace ck_tile::tensor_layout::convolution;
+
     std::vector<std::size_t> physical_lengths;
 
-    if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
-                 std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
-                 std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
+    if constexpr(is_any_of<InLayout, GNCW, GNCHW, GNCDHW>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.N_),
@@ -97,9 +81,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
                               param.input_spatial_lengths_.begin(),
                               param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
-                      std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
-                      std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
+    else if constexpr(is_any_of<InLayout, GNWC, GNHWC, GNDHWC>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.N_),
@@ -109,9 +91,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
                               param.input_spatial_lengths_.begin(),
                               param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
-                      std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
-                      std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
+    else if constexpr(is_any_of<InLayout, NWGC, NHWGC, NDHWGC>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
                                                     static_cast<std::size_t>(param.G_),

@@ -139,11 +119,11 @@ template <typename WeiLayout>
 CK_TILE_HOST HostTensorDescriptor
 make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
 {
+    using namespace ck_tile::tensor_layout::convolution;
+
     std::vector<std::size_t> physical_lengths;
 
-    if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
-                 std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
-                 std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
+    if constexpr(is_any_of<WeiLayout, KXC, KYXC, KZYXC>::value)
     {
         if(param.G_ != 1)
         {
@@ -157,9 +137,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
                               param.filter_spatial_lengths_.begin(),
                               param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
-                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
-                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
+    else if constexpr(is_any_of<WeiLayout, GKCX, GKCYX, GKCZYX>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.K_),
@@ -169,9 +147,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
                               param.filter_spatial_lengths_.begin(),
                               param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
-                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
-                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
+    else if constexpr(is_any_of<WeiLayout, GKXC, GKYXC, GKZYXC>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.K_),
@@ -181,9 +157,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
                               param.filter_spatial_lengths_.begin(),
                               param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
-                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
-                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
+    else if constexpr(is_any_of<WeiLayout, KXGC, KYXGC, KZYXGC>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
                                                     static_cast<std::size_t>(param.G_),

@@ -211,11 +185,11 @@ template <typename OutLayout>
 CK_TILE_HOST HostTensorDescriptor
 make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
 {
+    using namespace ck_tile::tensor_layout::convolution;
+
     std::vector<std::size_t> physical_lengths;
 
-    if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
-                 std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
-                 std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
+    if constexpr(is_any_of<OutLayout, GNKW, GNKHW, GNKDHW>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.N_),
@@ -226,9 +200,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
                               param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
     // separate from legacy code above
-    else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
-                      std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
-                      std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
+    else if constexpr(is_any_of<OutLayout, GNWK, GNHWK, GNDHWK>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.N_),
@@ -238,9 +210,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
                               param.output_spatial_lengths_.begin(),
                               param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
-                      std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
-                      std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
+    else if constexpr(is_any_of<OutLayout, NWGK, NHWGK, NDHWGK>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
                                                     static_cast<std::size_t>(param.G_),
include/ck_tile/host/device_memory.hpp  (view file @ 3c5717df)

@@ -7,6 +7,7 @@
 #include <stdint.h>
 #include <stdexcept>
 #include "ck_tile/host/hip_check_error.hpp"
+#include "ck_tile/host/host_tensor.hpp"
 
 namespace ck_tile {
 template <typename T>
@@ -36,6 +37,19 @@ struct DeviceMem
             mpDeviceBuf = nullptr;
         }
     }
+    template <typename T>
+    DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
+    {
+        if(mMemSize != 0)
+        {
+            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        }
+        else
+        {
+            mpDeviceBuf = nullptr;
+        }
+        ToDevice(t.data());
+    }
     void Realloc(std::size_t mem_size)
     {
         if(mpDeviceBuf)
@@ -92,6 +106,27 @@ struct DeviceMem
             HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
         }
     }
+    // construct a host tensor with type T
+    template <typename T>
+    HostTensor<T> ToHost(std::size_t cpySize)
+    {
+        // TODO: host tensor could be slightly larger than the device tensor
+        // we just copy all data from GPU buffer
+        std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
+        HostTensor<T> h_({host_elements});
+        if(mpDeviceBuf)
+        {
+            HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
+        }
+        return h_;
+    }
+    template <typename T>
+    HostTensor<T> ToHost()
+    {
+        return ToHost<T>(mMemSize);
+    }
     void SetZero() const
     {
         if(mpDeviceBuf)
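With the new constructor and ToHost, a host tensor can round-trip through device memory without manual size bookkeeping. A sketch:

    // Sketch: upload a host tensor, then read the device buffer back as a new tensor.
    ck_tile::HostTensor<float> x({1024});   // assumed filled elsewhere
    ck_tile::DeviceMem x_dev(x);            // allocates and copies to device
    // ... launch a kernel on x_dev.GetDeviceBuffer() ...
    auto y = x_dev.ToHost<float>();         // copies all mMemSize bytes back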
include/ck_tile/host/fill.hpp  (view file @ 3c5717df)

@@ -13,6 +13,7 @@
 #include <unordered_set>
 #include "ck_tile/core.hpp"
+#include "ck_tile/host/joinable_thread.hpp"
 
 namespace ck_tile {
 
@@ -22,13 +23,44 @@ struct FillUniformDistribution
     float a_{-5.f};
     float b_{5.f};
     std::optional<uint32_t> seed_{11939};
+    // ATTENTION: threaded does not guarantee the distribution between thread
+    bool threaded = false;
 
     template <typename ForwardIter>
     void operator()(ForwardIter first, ForwardIter last) const
     {
-        std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
-        std::uniform_real_distribution<float> dis(a_, b_);
-        std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
+        if(threaded)
+        {
+            uint32_t num_thread  = std::thread::hardware_concurrency();
+            auto total           = static_cast<std::size_t>(std::distance(first, last));
+            auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
+            std::vector<joinable_thread> threads(num_thread);
+            for(std::size_t it = 0; it < num_thread; ++it)
+            {
+                std::size_t iw_begin = it * work_per_thread;
+                std::size_t iw_end   = std::min((it + 1) * work_per_thread, total);
+                auto thread_f        = [this, total, iw_begin, iw_end, &first] {
+                    if(iw_begin > total || iw_end > total)
+                        return;
+                    // need to make each thread unique, add an offset to current seed
+                    std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
+                                                       : std::random_device{}());
+                    std::uniform_real_distribution<float> dis(a_, b_);
+                    std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
+                        return ck_tile::type_convert<T>(dis(gen));
+                    });
+                };
+                threads[it] = joinable_thread(thread_f);
+            }
+        }
+        else
+        {
+            std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
+            std::uniform_real_distribution<float> dis(a_, b_);
+            std::generate(first, last,
+                          [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
+        }
     }
 
     template <typename ForwardRange>

@@ -115,13 +147,44 @@ struct FillNormalDistribution
     float mean_{0.f};
     float variance_{1.f};
     std::optional<uint32_t> seed_{11939};
+    // ATTENTION: threaded does not guarantee the distribution between thread
+    bool threaded = false;
 
     template <typename ForwardIter>
     void operator()(ForwardIter first, ForwardIter last) const
     {
-        std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
-        std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
-        std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
+        if(threaded)
+        {
+            uint32_t num_thread  = std::thread::hardware_concurrency();
+            auto total           = static_cast<std::size_t>(std::distance(first, last));
+            auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
+            std::vector<joinable_thread> threads(num_thread);
+            for(std::size_t it = 0; it < num_thread; ++it)
+            {
+                std::size_t iw_begin = it * work_per_thread;
+                std::size_t iw_end   = std::min((it + 1) * work_per_thread, total);
+                auto thread_f        = [this, total, iw_begin, iw_end, &first] {
+                    if(iw_begin > total || iw_end > total)
+                        return;
+                    // need to make each thread unique, add an offset to current seed
+                    std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
+                                                       : std::random_device{}());
+                    std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
+                    std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
+                        return ck_tile::type_convert<T>(dis(gen));
+                    });
+                };
+                threads[it] = joinable_thread(thread_f);
+            }
+        }
+        else
+        {
+            std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
+            std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
+            std::generate(first, last,
+                          [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
+        }
     }
 
     template <typename ForwardRange>

@@ -235,6 +298,44 @@ struct FillMonotonicSeq
     }
 };
 
+template <typename T, bool IsAscending = true>
+struct FillStepRange
+{
+    float start_value_{0};
+    float end_value_{3};
+    float step_{1};
+
+    template <typename ForwardIter>
+    void operator()(ForwardIter first, ForwardIter last) const
+    {
+        std::generate(first, last, [=, n = start_value_]() mutable {
+            auto tmp = n;
+            n += step_;
+            if constexpr(IsAscending)
+            {
+                if(n > end_value_)
+                    n = start_value_;
+            }
+            else
+            {
+                if(n < end_value_)
+                    n = start_value_;
+            }
+            return type_convert<T>(tmp);
+        });
+    }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) const -> std::void_t<decltype(
+        std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
+                                             std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
+};
+
 template <typename T>
 struct FillConstant
 {
View file @
3c5717df
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -8,12 +8,13 @@
...
@@ -8,12 +8,13 @@
#include <iostream>
#include <iostream>
#include <iomanip>
#include <iomanip>
#include <numeric>
#include <numeric>
#include <thread>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include <functional>
#include <functional>
#include <fstream>
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/ranges.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
...
@@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
return
HostTensorDescriptor
(
new_lengths
,
new_strides
);
return
HostTensorDescriptor
(
new_lengths
,
new_strides
);
}
}
struct
joinable_thread
:
std
::
thread
{
template
<
typename
...
Xs
>
joinable_thread
(
Xs
&&
...
xs
)
:
std
::
thread
(
std
::
forward
<
Xs
>
(
xs
)...)
{
}
joinable_thread
(
joinable_thread
&&
)
=
default
;
joinable_thread
&
operator
=
(
joinable_thread
&&
)
=
default
;
~
joinable_thread
()
{
if
(
this
->
joinable
())
this
->
join
();
}
};
template
<
typename
F
,
typename
...
Xs
>
template
<
typename
F
,
typename
...
Xs
>
struct
ParallelTensorFunctor
struct
ParallelTensorFunctor
{
{
...
@@ -590,7 +574,147 @@ struct HostTensor
...
@@ -590,7 +574,147 @@ struct HostTensor
size
()
*
FromSize
/
ToSize
};
size
()
*
FromSize
/
ToSize
};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
HostTensor
<
T
>&
t
)
{
os
<<
t
.
mDesc
;
os
<<
"["
;
for
(
typename
Data
::
size_type
idx
=
0
;
idx
<
t
.
mData
.
size
();
++
idx
)
{
if
(
0
<
idx
)
{
os
<<
", "
;
}
if
constexpr
(
std
::
is_same_v
<
T
,
bf16_t
>
||
std
::
is_same_v
<
T
,
fp16_t
>
)
{
os
<<
type_convert
<
float
>
(
t
.
mData
[
idx
])
<<
" #### "
;
}
else
{
os
<<
t
.
mData
[
idx
];
}
}
os
<<
"]"
;
return
os
;
}
// read data from a file, as dtype
// the file could dumped from torch as (targeting tensor is t here)
// numpy.savetxt("f.txt", t.view(-1).numpy())
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # from cuda to cpu to save
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int
// will output f.txt, each line is a value
// dtype=float or int, internally will cast to real type
void
loadtxt
(
std
::
string
file_name
,
std
::
string
dtype
=
"float"
)
{
std
::
ifstream
file
(
file_name
);
if
(
file
.
is_open
())
{
std
::
string
line
;
index_t
cnt
=
0
;
while
(
std
::
getline
(
file
,
line
))
{
if
(
cnt
>=
static_cast
<
index_t
>
(
mData
.
size
()))
{
throw
std
::
runtime_error
(
std
::
string
(
"data read from file:"
)
+
file_name
+
" is too big"
);
}
if
(
dtype
==
"float"
)
{
mData
[
cnt
]
=
type_convert
<
T
>
(
std
::
stof
(
line
));
}
else
if
(
dtype
==
"int"
||
dtype
==
"int32"
)
{
mData
[
cnt
]
=
type_convert
<
T
>
(
std
::
stoi
(
line
));
}
cnt
++
;
}
file
.
close
();
if
(
cnt
<
static_cast
<
index_t
>
(
mData
.
size
()))
{
std
::
cerr
<<
"Warning! reading from file:"
<<
file_name
<<
", does not match the size of this tensor"
<<
std
::
endl
;
}
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw
std
::
runtime_error
(
std
::
string
(
"unable to open file:"
)
+
file_name
);
}
}
// can save to a txt file and read from torch as:
// torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
void
savetxt
(
std
::
string
file_name
,
std
::
string
dtype
=
"float"
)
{
std
::
ofstream
file
(
file_name
);
if
(
file
.
is_open
())
{
for
(
auto
&
itm
:
mData
)
{
if
(
dtype
==
"float"
)
file
<<
type_convert
<
float
>
(
itm
)
<<
std
::
endl
;
else
if
(
dtype
==
"int"
)
file
<<
type_convert
<
int
>
(
itm
)
<<
std
::
endl
;
else
// TODO: we didn't implement operator<< for all custom
// data types, here fall back to float in case compile error
file
<<
type_convert
<
float
>
(
itm
)
<<
std
::
endl
;
}
file
.
close
();
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw
std
::
runtime_error
(
std
::
string
(
"unable to open file:"
)
+
file_name
);
}
}
Descriptor
mDesc
;
Descriptor
mDesc
;
Data
mData
;
Data
mData
;
};
};
template
<
bool
is_row_major
>
auto
host_tensor_descriptor
(
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
bool_constant
<
is_row_major
>
)
{
using
namespace
ck_tile
::
literals
;
if
constexpr
(
is_row_major
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
}
template
<
bool
is_row_major
>
auto
get_default_stride
(
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
bool_constant
<
is_row_major
>
)
{
if
(
stride
==
0
)
{
if
constexpr
(
is_row_major
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
}
}
// namespace ck_tile
}
// namespace ck_tile
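The new text I/O pairs with one-value-per-line files written by numpy/torch, as the comments note. A round-trip sketch on the host side:

    // Sketch: dump a tensor to text and read it back (one value per line).
    ck_tile::HostTensor<float> t({4, 4});
    // ... fill t ...
    t.savetxt("f.txt", "float");
    ck_tile::HostTensor<float> u({4, 4});
    u.loadtxt("f.txt", "float"); // throws if f.txt holds more values than u can hold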
include/ck_tile/host/joinable_thread.hpp  (new file, mode 100644, view @ 3c5717df)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <thread>
#include <utility>

namespace ck_tile {

struct joinable_thread : std::thread
{
    template <typename... Xs>
    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
    {
    }

    joinable_thread(joinable_thread&&) = default;
    joinable_thread& operator=(joinable_thread&&) = default;

    ~joinable_thread()
    {
        if(this->joinable())
            this->join();
    }
};

} // namespace ck_tile
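This is the joinable_thread previously defined inside host_tensor.hpp, hoisted into its own header so fill.hpp can use it too. Unlike a raw std::thread, it joins in its destructor, so a scoped worker can never trigger std::terminate. A sketch:

    // Sketch: threads join automatically at the end of the scope.
    {
        std::vector<ck_tile::joinable_thread> workers;
        for(int i = 0; i < 4; ++i)
            workers.emplace_back([i] { /* ... per-thread work ... */ });
    } // all workers joined here, no explicit join() needed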
include/ck_tile/host/reference/reference_batched_transpose.hpp  (new file, mode 100644, view @ 3c5717df)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <thread>

namespace ck_tile {

template <typename Type>
CK_TILE_HOST void reference_batched_transpose(const HostTensor<Type>& x,
                                              HostTensor<Type>& y,
                                              std::string layout_in  = "NCHW",
                                              std::string layout_out = "NHWC")
{
    const int N = x.mDesc.get_lengths()[0];

    auto f = [&](auto batch) {
        if(layout_in == "NCHW" && layout_out == "NHWC")
        {
            const int C = x.mDesc.get_lengths()[1];
            const int H = x.mDesc.get_lengths()[2];
            const int W = x.mDesc.get_lengths()[3];
            for(int c = 0; c < C; ++c)
            {
                for(int h = 0; h < H; ++h)
                {
                    for(int w = 0; w < W; ++w)
                    {
                        Type v_x          = x(batch, c, h, w);
                        y(batch, h, w, c) = v_x;
                    }
                }
            }
        }
        else if(layout_in == "NHWC" && layout_out == "NCHW")
        {
            const int H = x.mDesc.get_lengths()[1];
            const int W = x.mDesc.get_lengths()[2];
            const int C = x.mDesc.get_lengths()[3];
            for(int h = 0; h < H; ++h)
            {
                for(int w = 0; w < W; ++w)
                {
                    for(int c = 0; c < C; ++c)
                    {
                        Type v_x          = x(batch, h, w, c);
                        y(batch, c, h, w) = v_x;
                    }
                }
            }
        }
    };

    make_ParallelTensorFunctor(f, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
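A sketch of checking an NCHW-to-NHWC transpose kernel against this reference (shapes are illustrative):

    // Sketch: produce the reference NHWC output for an NCHW input.
    ck_tile::HostTensor<float> x({2, 3, 4, 5}); // N, C, H, W
    ck_tile::HostTensor<float> y({2, 4, 5, 3}); // N, H, W, C
    ck_tile::reference_batched_transpose(x, y, "NCHW", "NHWC");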
include/ck_tile/host/reference/reference_fused_moe.hpp  (new file, mode 100644, view @ 3c5717df)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                              tok-0      tok-1      tok-2      tok-3      tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
//  (only for reference)    exp-0   exp-1    exp-2       exp-3       exp-4    exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
//                         0, 1, 2, 5]
//                        |- exp-0 -|- exp-1 -|- exp-2 -|-      exp-3      -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr    : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
//                         c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
template <typename AccDataType, // you only need to explcitly set this one
          typename Activation,  // ck_tile::element_wise::Gelu
          typename ADataType,
          typename GDataType,
          typename DDataType,
          typename ODataType,
          typename AScaleDataType,
          typename GScaleDataType,
          typename DScaleDataType,
          typename YSmoothScaleDataType,
          typename TopkWeightDataType,
          typename IndexDataType>
void reference_fused_moe(
    const ck_tile::HostTensor<ADataType>& a_host,       // [tokens, hidden_size]
    const ck_tile::HostTensor<GDataType>& g_host,       // [experts, interme_size_0, hidden_size]
    const ck_tile::HostTensor<DDataType>& d_host,       // [experts, hidden_size, interme_size_1]
    const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1]
    const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
    const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size]
    const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
    ck_tile::HostTensor<ODataType>& o_host,                   // [tokens, hidden_size]
    const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host,   // [max_num_tokens_padded]
    const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
    const ck_tile::HostTensor<IndexDataType>& sorted_expert_ids_host,
    // [(max_num_tokens_padded + block_size - 1) / block_size]
    const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
    const ck_tile::HostTensor<IndexDataType>&
        token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
    ck_tile::index_t block_m,
    ck_tile::index_t tokens,
    ck_tile::index_t experts,
    ck_tile::index_t hidden_size,
    ck_tile::index_t intermediate_size, // this size is for gate/up/down
    ck_tile::index_t topk,
    ck_tile::index_t gate_only)
{
    assert(sorted_token_ids_host.get_num_of_dimension() == 1);
    assert(sorted_weight_host.get_num_of_dimension() == 1);
    assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
    assert(num_sorted_tiles_host.get_element_size() == 1);

    ck_tile::index_t num_sorted_tiles    = num_sorted_tiles_host.mData[0] / block_m;
    ck_tile::index_t intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2);
    ck_tile::index_t intermediate_size_1 = intermediate_size;

    ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});

    int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
    // assert();
    auto f = [&](auto i_flatten) {
        ck_tile::index_t i_tile = i_flatten / block_m;
        if(i_tile >= num_sorted_tiles)
            return;
        ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
        ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
        ck_tile::index_t i_topk  = i_token >> 24;
        i_token &= 0xffffff;
        if(i_token >= tokens)
            return;
        (void)token_ids_host;
#else
        // TODO: better remove this in the future, or modify the token_id value
        auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
            for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
            {
                if(token_ids_host(token_id_, i_) == expert_id_)
                    return i_;
            }
            throw std::runtime_error("not correct token/expert pair\n");
            return -1; // TODO: not correct!!
        };
        ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
        if(i_token >= tokens)
            return;
        ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
#endif
        auto weight = sorted_weight_host.mData[i_flatten];

        ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
        // first gemm
        for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
        {
            AccDataType acc = static_cast<AccDataType>(0);
            for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
            {
                acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
                       type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
            }
            acc_0(0, i_n) = acc;
            // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
        }

        ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
        if(gate_only)
        {
            if(intermediate_size_1 != intermediate_size_0)
                throw std::runtime_error("intermediate_size not correct, 0:" +
                                         std::to_string(intermediate_size_0) +
                                         ", 1:" + std::to_string(intermediate_size_1));
            for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
            {
                Activation{}(y(0, i_n), acc_0(0, i_n));
                // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
            }
        }
        else
        {
            if(intermediate_size_1 * 2 != intermediate_size_0)
                throw std::runtime_error("intermediate_size not correct, 0:" +
                                         std::to_string(intermediate_size_0) +
                                         ", 1:" + std::to_string(intermediate_size_1));
            for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
            {
                AccDataType tmp;
                Activation{}(tmp, acc_0(0, i_n));
                y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
            }
        }

        // second gemm, loop along gemm-n
        ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            AccDataType acc = static_cast<AccDataType>(0);
            for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
            {
                acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
            }
            acc_1(0, i_n) = acc * weight; // multiple weight here
        }

        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
        }
    };
    // make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
    make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);

    // reduce
    auto r = [&](auto i_token) {
        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            AccDataType acc = type_convert<AccDataType>(0);
            for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
            {
                acc += out_topk_tokens(i_token, i_topk, i_n);
            }
            o_host(i_token, i_n) = type_convert<ODataType>(acc);
        }
    };
    make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());

    (void)num_sorted_tiles_host;
    (void)sa_host;
    (void)sg_host;
    (void)sd_host;
    (void)sy_host;
}
} // namespace ck_tile
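As a worked check of the padding bound from the header comment, with its numbers (tokens = 5, topk = 3, experts = 6, block_m = M_a = 4) the updated formula gives a value comfortably above the 28 entries actually produced in the example:

    // Sketch: worst-case padded length used to size the flat loop above.
    int tokens = 5, topk = 3, experts = 6, block_m = 4;
    int max_num_tokens_padded = topk * tokens + experts * block_m - topk; // 15 + 24 - 3 = 36
    // 36 >= 28, the actual num_tokens_post_padded in the example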
include/ck_tile/host/reference/reference_gemm.hpp
View file @
3c5717df
...
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
...
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
col
*
strideB
+
k
?
col
*
strideB
+
k
:
k
*
strideB
+
col
;
:
k
*
strideB
+
col
;
acc
+=
static_cast
<
AccDataType
>
(
A
[
a_index
])
*
static_cast
<
AccDataType
>
(
B
[
b_index
]);
acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
A
[
a_index
])
*
ck_tile
::
type_convert
<
AccDataType
>
(
B
[
b_index
]);
}
}
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideC
+
col
?
row
*
strideC
+
col
:
col
*
strideC
+
row
;
:
col
*
strideC
+
row
;
C
[
c_index
]
=
acc
;
C
[
c_index
]
=
ck_tile
::
type_convert
<
CDataType
>
(
acc
)
;
}
}
}
}
...
@@ -97,9 +98,9 @@ template <typename ADataType,
...
@@ -97,9 +98,9 @@ template <typename ADataType,
typename
LayoutA
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutB
,
typename
LayoutC
>
typename
LayoutC
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_gemm_gpu
(
ADataType
*
a_ptr
,
DeviceMem
&
b_device
,
BDataType
*
b_ptr
,
DeviceMem
&
c_device
,
CDataType
*
c_ptr
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
...
@@ -107,78 +108,50 @@ void reference_gemm_gpu(DeviceMem& a_device,
                         index_t stride_b,
                         index_t stride_c)
 {
-    ADataType* d_A;
-    BDataType* d_B;
-    CDataType* d_C;
-    hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
-    hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
-    hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
-                  << std::endl;
-        return; // Early exit on error
-    }
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
-                  << std::endl;
-        return; // Early exit on error
-    }
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
-                  << std::endl;
-        return; // Early exit on error
-    }
-    errA = hipMemcpy(d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
-    }
-    errB = hipMemcpy(d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
-    }
     int totalElements      = M * N;
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
     naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
-        <<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
-    errC = hipMemcpy(c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
-    }
-    errA = hipFree(d_A);
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
-    }
-    errB = hipFree(d_B);
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
-    }
-    errC = hipFree(d_C);
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
-    }
+        <<<numBlocks, numThreadsPerBlock>>>(
+            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
     return;
 }
+
+template <typename ADataType,
+          typename BDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename LayoutA,
+          typename LayoutB,
+          typename LayoutC>
+void reference_batched_gemm_gpu(ADataType* a_ptr,
+                                BDataType* b_ptr,
+                                CDataType* c_ptr,
+                                index_t M,
+                                index_t N,
+                                index_t K,
+                                index_t stride_a,
+                                index_t stride_b,
+                                index_t stride_c,
+                                index_t batch_stride_A,
+                                index_t batch_stride_B,
+                                index_t batch_stride_C,
+                                index_t batch_count)
+{
+    int totalElements      = M * N;
+    int numThreadsPerBlock = 256; // Common choice for threads per block
+    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
+
+    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
+    {
+        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
+        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
+        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
+        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
+            <<<numBlocks, numThreadsPerBlock>>>(
+                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }
+    return;
+}
...
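Review note: with the DeviceMem reference parameters gone, allocation and host/device transfers move to the caller. A minimal usage sketch (assuming ck_tile::DeviceMem's GetDeviceBuffer/ToDevice/FromDevice helpers as used elsewhere in the host utilities; host buffer names are illustrative):

ck_tile::DeviceMem a_buf(M * K * sizeof(ADataType));
ck_tile::DeviceMem b_buf(N * K * sizeof(BDataType));
ck_tile::DeviceMem c_buf(M * N * sizeof(CDataType));
a_buf.ToDevice(a_host.data()); // host->device copies are now the caller's job
b_buf.ToDevice(b_host.data());
reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>(
    static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
    static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
    static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
    M, N, K, stride_a, stride_b, stride_c);
c_buf.FromDevice(c_host.data()); // ...and so is the device->host copy of the result

For reference_batched_gemm_gpu, each batch_stride_* is the element offset between consecutive batches, e.g. batch_stride_A = M * K for densely packed A.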
include/ck_tile/host/reference/reference_moe_sorting.hpp View file @ 3c5717df
...
@@ -8,6 +8,9 @@
 namespace ck_tile {

+#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
+    static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
+
 template <typename WeightType, typename IndexType = index_t>
 CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
                                         const HostTensor<WeightType>& weights,
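The new macro packs a (token, topk) pair into a single 32-bit mock id: the token index occupies the low 24 bits and the topk slot the high 8 bits, so the scheme supports at most 2^24 tokens and 256 topk slots. A small decode sketch (variable names are illustrative):

uint32_t packed     = MOE_SORTING_MOCK_ID(token, k); // ((k & 0xff) << 24) | (token & 0xffffff)
uint32_t token_part = packed & 0x00ffffffu;          // low 24 bits: token index
uint32_t topk_part  = packed >> 24;                  // high 8 bits: topk slot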
...
@@ -20,8 +23,14 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
 {
     const index_t num_token = topk_ids.mDesc.get_lengths()[0];
     const index_t topk      = topk_ids.mDesc.get_lengths()[1];
+    // allocate a temp buffer, and fill the value with [number_token|topk]
     std::vector<std::vector<IndexType>> expert_tokens(experts,
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+                                                      std::vector<IndexType>(unit_size, MOE_SORTING_MOCK_ID(num_token, topk)));
+#else
                                                       std::vector<IndexType>(unit_size, num_token));
+#endif
     std::vector<std::vector<WeightType>> expert_token_weights(
         experts, std::vector<WeightType>(unit_size, 0));
     std::vector<IndexType> expert_slices(experts, 1);
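Note that the fill value doubles as a padding sentinel: valid token ids run from 0 to num_token - 1, so a slot still holding the fill value after sorting is padding. A host-side check sketch (illustrative, not part of the header):

#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
bool is_padding = expert_tokens[e][i] == static_cast<IndexType>(MOE_SORTING_MOCK_ID(num_token, topk));
#else
bool is_padding = expert_tokens[e][i] == num_token;
#endif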
...
@@ -42,12 +51,19 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
             expert_token_weights[e].resize(new_size);
             for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
             {
-                expert_tokens[e][i]        = num_token;
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+                expert_tokens[e][i] = MOE_SORTING_MOCK_ID(num_token, topk);
+#else
+                expert_tokens[e][i] = num_token;
+#endif
                 expert_token_weights[e][i] = 0;
             }
         }
-        expert_tokens[e][idx]        = t;
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+        expert_tokens[e][idx] = MOE_SORTING_MOCK_ID(t, k);
+#else
+        expert_tokens[e][idx] = t;
+#endif
         expert_token_weights[e][idx] = w;
         expert_slice_idxs[e]++;
     }
...
@@ -75,4 +91,7 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
     unit_cnt *= unit_size;
     return;
 }
+
+#undef MOE_SORTING_MOCK_ID
 } // namespace ck_tile
include/ck_tile/host/reference/reference_permute.hpp View file @ 3c5717df
...
@@ -16,7 +16,7 @@ namespace ck_tile {
 */
 template <typename DataType>
 CK_TILE_HOST void
-reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
+reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
 {
     const auto x_len = x.mDesc.get_lengths();
     const auto y_len = y.mDesc.get_lengths();
...
@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
         std::vector<size_t> tmp(rank, 0);
         for(index_t i = 0; i < rank; i++)
         {
-            tmp[dims[i]] = y_coord[i];
+            tmp[perm[i]] = y_coord[i];
         }
         return tmp;
     }();
...
@@ -54,4 +54,23 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
     make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
 }
+
+template <typename DataType>
+CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
+{
+    auto x_shape          = x.get_lengths();
+    ck_tile::index_t rank = perm.size();
+
+    std::vector<ck_tile::index_t> y_shape = [&]() {
+        std::vector<ck_tile::index_t> tmp(rank, 0);
+        for(int i = 0; i < static_cast<int>(rank); i++)
+        {
+            tmp[i] = x_shape[perm[i]];
+        }
+        return tmp;
+    }();
+
+    HostTensor<DataType> y(y_shape);
+    reference_permute(x, y, perm);
+    return y;
+}
 } // namespace ck_tile
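A usage sketch for the new value-returning overload (shapes are illustrative): it builds the output lengths from the permutation and forwards to the in-place reference above.

ck_tile::HostTensor<float> x({4, 8, 16});
auto y = ck_tile::reference_permute(x, {2, 0, 1}); // y has lengths {16, 4, 8}
// element mapping: y(w, n, h) == x(n, h, w)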
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp View file @ 3c5717df
...
@@ -8,16 +8,40 @@
 namespace ck_tile {

+// Note: for simplicity, each functor only care about single M
+struct reference_rmsnorm2d_default_epilogue
+{
+    template <typename OutDataType, typename AccDataType>
+    void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
+    {
+        const int N = acc.mDesc.get_lengths()[1];
+        for(int n = 0; n < N; ++n)
+        {
+            o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
+        }
+    }
+
+    template <typename OutDataType, typename AccDataType>
+    auto operator()(int m, const HostTensor<AccDataType>& acc)
+    {
+        HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
+        operator()(m, o, acc);
+        return o;
+    }
+};
+
 template <typename XDataType,
           typename GammaDataType,
           typename ComputeDataType,
           typename YDataType,
-          typename InvRmsDataType>
+          typename InvRmsDataType,
+          typename Epilogue = reference_rmsnorm2d_default_epilogue>
 void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
                              const HostTensor<GammaDataType>& gamma_n,
                              HostTensor<YDataType>& y_m_n,
                              HostTensor<InvRmsDataType>& invRms_m,
-                             ComputeDataType epsilon)
+                             ComputeDataType epsilon,
+                             Epilogue epilogue_functor = {})
 {
     auto rmsnorm2d_fwd_func = [&](auto m) {
         const int N = x_m_n.mDesc.get_lengths()[1];
...
@@ -37,13 +61,15 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
         if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
             invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);

+        HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
         for(int n = 0; n < N; ++n)
         {
             ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
             ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
-            auto y                = x * divisor * gamma;
-            y_m_n(m, n)           = ck_tile::type_convert<YDataType>(y);
+            acc(m, n)             = x * divisor * gamma;
         }
+        epilogue_functor(m, y_m_n, acc);
     };

     make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
...
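Any callable matching the single-M contract of reference_rmsnorm2d_default_epilogue can be passed as the new trailing argument. A minimal custom epilogue sketch (illustrative, not part of this commit) that scales each normalized row before the final conversion:

struct scale_epilogue
{
    float scale = 2.0f;

    template <typename OutDataType, typename AccDataType>
    void operator()(int m,
                    ck_tile::HostTensor<OutDataType>& o,
                    const ck_tile::HostTensor<AccDataType>& acc)
    {
        const int N = acc.mDesc.get_lengths()[1];
        for(int n = 0; n < N; ++n)
            o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n) * scale);
    }
};

// passed as the trailing parameter:
// reference_rmsnorm2d_fwd<XDataType, GammaDataType, ComputeDataType, YDataType,
//                         InvRmsDataType, scale_epilogue>(x, gamma, y, inv_rms, eps, scale_epilogue{});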
include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp View file @ 3c5717df
...
@@ -22,7 +22,7 @@ CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>&
         // scale = amax / 127 for int8
         auto v_scale = type_convert<XDataType>(scale_m(m));
         auto v_qx    = v_x / v_scale;
-        qx_m_n(m, n) = saturates<QXDataType>{}(v_qx);
+        qx_m_n(m, n) = type_convert<QXDataType>(saturates<QXDataType>{}(v_qx));
     }
 };
...
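The fix makes the narrowing explicit: saturates<QXDataType>{} clamps the value to the target type's range (presumably still in the compute type), and type_convert then produces the stored element type. A sketch for QXDataType = int8_t (values illustrative):

float v_qx   = 131.4f;                                 // quantized value, outside int8 range
auto clamped = ck_tile::saturates<int8_t>{}(v_qx);     // clamped to 127
int8_t q     = ck_tile::type_convert<int8_t>(clamped); // explicit conversion to the stored type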
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp View file @ 3c5717df
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp View file @ 3c5717df
...
@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
     static constexpr bool kSaveX             = Problem::kSaveX;
     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM              = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
     static constexpr bool kPadN              = Problem::kPadN;
+    static constexpr bool UseMax3            = true; // TODO - Move to trait

     static constexpr const char* name = []() {
         if constexpr(kNeedCrossWarpSync)
...
@@ -69,9 +70,16 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
         auto reduce_square_sum_func = ReduceOp::SquareAdd{};
         auto reduce_sum_func        = ReduceOp::Add{};
         auto reduce_absmax_func     = ReduceOp::AbsMax{};
+        auto reduce_absmax3_func    = [](auto acc_, auto v_0_, auto v_1_) {
+            float rtn;
+            asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
+                         : "=v"(rtn)
+                         : "v"(acc_), "v"(v_0_), "v"(v_1_));
+            return rtn;
+        };
         auto reduce_max_func        = ReduceOp::Max{};

         auto block_reduce2d      = Policy::template GetBlockReduce2d<Problem>();
         auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
         auto block_reduce2d_cross_warp_sync =
             Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
...
@@ -116,8 +124,23 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
         });

         // compute absmax, each-thread->cross-lane->cross-warp
-        auto absmax = block_reduce2d(
-            y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
+        auto absmax = [&]() {
+            constexpr auto x_size_per_row =
+                x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
+            if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
+                         x_size_per_row % 2 == 0)
+            {
+                return block_reduce2d(y,
+                                      reduce_absmax_func.GetIdentityValue<ComputeDataType>(),
+                                      reduce_absmax3_func,
+                                      sequence<1, 2>{});
+            }
+            else
+            {
+                return block_reduce2d(
+                    y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
+            }
+        }();
         block_reduce2d_sync(absmax, reduce_max_func);
         block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
...
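v_max3_f32 folds two absolute values into the running maximum in a single VALU instruction, halving the per-thread reduction steps compared with pairwise AbsMax. A scalar equivalent of the inline-asm lambda above (illustrative):

#include <algorithm>
#include <cmath>

float reduce_absmax3(float acc_, float v_0_, float v_1_)
{
    // mirrors "v_max3_f32 %0, %1, abs(%2), abs(%3)"; acc_ is already non-negative
    return std::max(acc_, std::max(std::fabs(v_0_), std::fabs(v_1_)));
}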
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp View file @ 3c5717df
...
@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
     static constexpr bool kSaveX             = Problem::kSaveX;
     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM              = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
     static constexpr bool kPadN              = Problem::kPadN;
+    static constexpr bool UseMax3            = true; // TODO - Move to trait

     static constexpr const char* name = []() {
         if constexpr(kNeedCrossWarpSync)
...
@@ -76,9 +77,16 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
         auto reduce_square_sum_func = ReduceOp::SquareAdd{};
         auto reduce_sum_func        = ReduceOp::Add{};
         auto reduce_absmax_func     = ReduceOp::AbsMax{};
+        auto reduce_absmax3_func    = [](auto acc_, auto v_0_, auto v_1_) {
+            float rtn;
+            asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
+                         : "=v"(rtn)
+                         : "v"(acc_), "v"(v_0_), "v"(v_1_));
+            return rtn;
+        };
         auto reduce_max_func        = ReduceOp::Max{};

         auto block_reduce2d      = Policy::template GetBlockReduce2d<Problem>();
         auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
         auto block_reduce2d_cross_warp_sync =
             Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
...
@@ -177,7 +185,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
             y(idx) = type_convert<ComputeDataType>(y_);
         });
-        block_reduce2d(y, absmax, reduce_absmax_func);
+        constexpr auto x_size_per_row =
+            x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
+        if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> && x_size_per_row % 2 == 0)
+            block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{});
+        else
+            block_reduce2d(y, absmax, reduce_absmax_func);

         if constexpr(kSaveX)
             move_tile_window(x_window, {0, -Block_N});
...
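Reading the guard together with the functor's arity: the sequence<1, 2>{} variant appears to feed the three-operand functor two row elements per step, which is why x_size_per_row must be even. A host-side sketch of that access pattern (illustrative, using the scalar reduce_absmax3 sketched earlier):

float absmax = 0.0f; // identity value for AbsMax
for(int i = 0; i < n; i += 2) // n = per-thread elements in the row, must be even
    absmax = reduce_absmax3(absmax, v[i], v[i + 1]);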