Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
041ac4c9
Commit
041ac4c9
authored
Dec 06, 2024
by
Jing Zhang
Browse files
add pk_i4_t as a struct
parent
23f99eb4
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
46 additions
and
25 deletions
+46
-25
include/ck/tensor/static_tensor.hpp
include/ck/tensor/static_tensor.hpp
+2
-8
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+8
-2
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+30
-1
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+2
-8
include/ck/utility/static_buffer.hpp
include/ck/utility/static_buffer.hpp
+2
-4
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+2
-2
No files found.
include/ck/tensor/static_tensor.hpp
View file @
041ac4c9
...
@@ -166,10 +166,7 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -166,10 +166,7 @@ struct StaticTensorTupleOfVectorBuffer
// Get X
// Get X
// Idx is for S, not X. Idx should be aligned with X
// Idx is for S, not X. Idx should be aligned with X
template
<
typename
X
,
template
<
typename
X
,
typename
Idx
,
typename
Idx
>
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
&&
is_known_at_compile_time
<
Idx
>::
value
&&
Idx
::
Size
()
==
ndim_
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
X
GetAsType
(
Idx
)
const
__host__
__device__
constexpr
X
GetAsType
(
Idx
)
const
{
{
constexpr
auto
coord
=
make_tensor_coordinate
(
desc_
,
to_multi_index
(
Idx
{}));
constexpr
auto
coord
=
make_tensor_coordinate
(
desc_
,
to_multi_index
(
Idx
{}));
...
@@ -200,10 +197,7 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -200,10 +197,7 @@ struct StaticTensorTupleOfVectorBuffer
// Set X
// Set X
// Idx is for S, not X. Idx should be aligned with X
// Idx is for S, not X. Idx should be aligned with X
template
<
typename
X
,
template
<
typename
X
,
typename
Idx
,
typename
Idx
>
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
&&
is_known_at_compile_time
<
Idx
>::
value
&&
Idx
::
Size
()
==
ndim_
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
void
SetAsType
(
Idx
,
X
x
)
__host__
__device__
constexpr
void
SetAsType
(
Idx
,
X
x
)
{
{
constexpr
auto
coord
=
make_tensor_coordinate
(
desc_
,
to_multi_index
(
Idx
{}));
constexpr
auto
coord
=
make_tensor_coordinate
(
desc_
,
to_multi_index
(
Idx
{}));
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
041ac4c9
...
@@ -39,7 +39,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
...
@@ -39,7 +39,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
__host__
__device__
inline
half2_t
pki4_to_half2
(
pk_i4_t
q
)
__host__
__device__
inline
half2_t
pki4_to_half2
(
pk_i4_t
q
)
{
{
#if
0
#if
1
uint8_t
x_u8
=
ck
::
bit_cast
<
uint8_t
>
(
q
);
uint8_t
x_u8
=
ck
::
bit_cast
<
uint8_t
>
(
q
);
uint32_t
i4s
=
((
x_u8
&
0x0f
)
<<
16
)
|
((
x_u8
&
0xf0
)
>>
4
);
uint32_t
i4s
=
((
x_u8
&
0x0f
)
<<
16
)
|
((
x_u8
&
0xf0
)
>>
4
);
...
@@ -118,7 +118,7 @@ struct PassThroughPack8
...
@@ -118,7 +118,7 @@ struct PassThroughPack8
__host__
__device__
constexpr
void
operator
()(
ck
::
half8_t
&
y
,
const
ck
::
pk_i4x4_t
&
x
)
const
__host__
__device__
constexpr
void
operator
()(
ck
::
half8_t
&
y
,
const
ck
::
pk_i4x4_t
&
x
)
const
{
{
#if
0
#if
1
vector_type
<
half_t
,
8
>
result
;
vector_type
<
half_t
,
8
>
result
;
result
.
template
AsType
<
half4_t
>()(
Number
<
0
>
{})
=
pki4_to_half4
(
bit_cast
<
int
>
(
x
));
result
.
template
AsType
<
half4_t
>()(
Number
<
0
>
{})
=
pki4_to_half4
(
bit_cast
<
int
>
(
x
));
...
@@ -252,6 +252,12 @@ struct PassThrough final : public UnaryOpBase
...
@@ -252,6 +252,12 @@ struct PassThrough final : public UnaryOpBase
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
// Pass-through for packed int4 values: the output is a plain copy of the input.
template <>
__host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
{
    y = x;
}
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
float
,
double
>
(
float
&
y
,
const
double
&
x
)
const
__host__
__device__
void
operator
()
<
float
,
double
>
(
float
&
y
,
const
double
&
x
)
const
{
{
...
...
include/ck/utility/data_type.hpp
View file @
041ac4c9
...
@@ -13,7 +13,14 @@ using half_t = _Float16;
...
@@ -13,7 +13,14 @@ using half_t = _Float16;
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
using
f8_t
=
_BitInt
(
8
);
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
pk_i4_t
=
uint8_t
;
// pk_i4_t used to be a plain alias (using pk_i4_t = uint8_t;); it is now a struct.
// Distinct wrapper type around a single int8_t storage byte.
// NOTE(review): the reference GEMM in this commit masks 4-bit values out of
// `data`, so the byte presumably carries two packed int4 values - confirm.
struct pk_i4_t
{
    // Underlying storage type.
    using type = int8_t;

    // Raw storage byte.
    type data;

    // Default construction zero-initializes the storage.
    __host__ __device__ constexpr pk_i4_t() : data{} {}

    // Implicit construction from a raw storage value.
    __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
{
...
@@ -168,6 +175,13 @@ struct scalar_type<int4_t>
...
@@ -168,6 +175,13 @@ struct scalar_type<int4_t>
};
};
#endif
#endif
// scalar_type trait for pk_i4_t: the type is its own scalar type and
// counts as a vector of size 1.
template <>
struct scalar_type<pk_i4_t>
{
    static constexpr index_t vector_size = 1;

    using type = pk_i4_t;
};
template
<
>
template
<
>
struct
scalar_type
<
f8_fnuz_t
>
struct
scalar_type
<
f8_fnuz_t
>
{
{
...
@@ -1047,6 +1061,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
...
@@ -1047,6 +1061,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using
type
=
bf8_ocp_t
::
data_type
;
using
type
=
bf8_ocp_t
::
data_type
;
};
};
// nnvb_data_t_selector for pk_i4_t: selects the struct's own storage
// type (pk_i4_t::type).
template <>
struct nnvb_data_t_selector<pk_i4_t>
{
    using type = pk_i4_t::type;
};
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
<
struct
non_native_vector_base
<
T
,
T
,
...
@@ -1166,6 +1186,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
...
@@ -1166,6 +1186,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
static
constexpr
index_t
vector_size
=
N
;
static
constexpr
index_t
vector_size
=
N
;
};
};
// scalar_type trait for an N-wide non-native vector of pk_i4_t:
// the scalar type is the vector base's data_t and vector_size is N.
template <index_t N>
struct scalar_type<non_native_vector_base<pk_i4_t, N>>
{
    static constexpr index_t vector_size = N;

    using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
};
// non-native vector_type implementation
// non-native vector_type implementation
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
...
@@ -1868,6 +1896,7 @@ using bf8x64_t = bf8x64_fnuz_t;
...
@@ -1868,6 +1896,7 @@ using bf8x64_t = bf8x64_fnuz_t;
using
pk_i4x2_t
=
typename
vector_type
<
pk_i4_t
,
2
>::
type
;
using
pk_i4x2_t
=
typename
vector_type
<
pk_i4_t
,
2
>::
type
;
using
pk_i4x4_t
=
typename
vector_type
<
pk_i4_t
,
4
>::
type
;
using
pk_i4x4_t
=
typename
vector_type
<
pk_i4_t
,
4
>::
type
;
// 8-wide pk_i4_t vector alias, alongside pk_i4x2_t and pk_i4x4_t.
using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
// u8
// u8
// using uint8x2_t = typename vector_type<uint8_t, 2>::type;
// using uint8x2_t = typename vector_type<uint8_t, 2>::type;
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
041ac4c9
...
@@ -59,10 +59,7 @@ struct DynamicBuffer
...
@@ -59,10 +59,7 @@ struct DynamicBuffer
__host__
__device__
constexpr
T
&
operator
()(
index_t
i
)
{
return
p_data_
[
i
];
}
__host__
__device__
constexpr
T
&
operator
()(
index_t
i
)
{
return
p_data_
[
i
];
}
template
<
typename
X
,
template
<
typename
X
>
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
{
{
// X contains multiple T
// X contains multiple T
...
@@ -204,10 +201,7 @@ struct DynamicBuffer
...
@@ -204,10 +201,7 @@ struct DynamicBuffer
element_space_size_
/
PackedSize
);
element_space_size_
/
PackedSize
);
}
}
template
<
typename
X
,
template
<
typename
X
>
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
{
// X contains multiple T
// X contains multiple T
...
...
include/ck/utility/static_buffer.hpp
View file @
041ac4c9
...
@@ -115,8 +115,7 @@ struct StaticBufferTupleOfVector
...
@@ -115,8 +115,7 @@ struct StaticBufferTupleOfVector
// Get X
// Get X
// i is offset of S, not X. i should be aligned to X
// i is offset of S, not X. i should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
index_t
I
,
index_t
I
>
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
GetAsType
(
Number
<
I
>
i
)
const
__host__
__device__
constexpr
auto
GetAsType
(
Number
<
I
>
i
)
const
{
{
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
...
@@ -133,8 +132,7 @@ struct StaticBufferTupleOfVector
...
@@ -133,8 +132,7 @@ struct StaticBufferTupleOfVector
// Set X
// Set X
// i is offset of S, not X. i should be aligned to X
// i is offset of S, not X. i should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
index_t
I
,
index_t
I
>
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
void
SetAsType
(
Number
<
I
>
i
,
X
x
)
__host__
__device__
constexpr
void
SetAsType
(
Number
<
I
>
i
,
X
x
)
{
{
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
041ac4c9
...
@@ -76,7 +76,7 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -76,7 +76,7 @@ struct ReferenceGemm : public device::BaseOperator
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
pk_i4_t
>
)
else
if
constexpr
(
is_same_v
<
ADataType
,
pk_i4_t
>
)
{
{
pk_i4
_t
i4x2
=
arg
.
a_m_k_
(
m
,
k
);
uint8
_t
i4x2
=
arg
.
a_m_k_
(
m
,
k
)
.
data
;
int8_t
i4
=
0
;
int8_t
i4
=
0
;
if
(
k
%
2
==
1
)
if
(
k
%
2
==
1
)
i4
=
(
i4x2
>>
0
)
&
0xf
;
i4
=
(
i4x2
>>
0
)
&
0xf
;
...
@@ -97,7 +97,7 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -97,7 +97,7 @@ struct ReferenceGemm : public device::BaseOperator
}
}
else
if
constexpr
(
is_same_v
<
BDataType
,
pk_i4_t
>
)
else
if
constexpr
(
is_same_v
<
BDataType
,
pk_i4_t
>
)
{
{
pk_i4
_t
i4x2
=
arg
.
b_k_n_
(
k
,
n
);
uint8
_t
i4x2
=
arg
.
b_k_n_
(
k
,
n
)
.
data
;
int8_t
i4
=
0
;
int8_t
i4
=
0
;
if
(
k
%
2
==
1
)
if
(
k
%
2
==
1
)
i4
=
(
i4x2
>>
0
)
&
0xf
;
i4
=
(
i4x2
>>
0
)
&
0xf
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment