gaoqiong / composable_kernel_ROCM · Commits

Commit cc485d80
authored Jan 09, 2025 by mtgu0705

pk_i4_t enabled based on commit "remove gfx12 targets from daily builds with rocm6.2 (#1560)"

parent cfac9497

Showing 6 changed files with 1076 additions and 28 deletions (+1076 −28)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp      +52  −22
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_pk4.hpp  +905 −0
include/ck/utility/amd_buffer_addressing.hpp                                          +2   −1
include/ck/utility/data_type.hpp                                                      +35  −0
library/include/ck/library/utility/host_tensor.hpp                                    +52  −5
library/include/ck/library/utility/host_tensor_generator.hpp                          +30  −0
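For orientation, a minimal host-side sketch of the packed-int4 convention these changes rely on (the helper names below are invented for illustration and are not part of the commit): pk_i4_t stores two 4-bit values in one int8_t, so a logical element count maps to roughly half as many bytes, which is why the tensor utilities in this diff divide sizes and offsets by 2.

// Hypothetical illustration only -- names are not from the commit.
#include <cstdint>
#include <cstddef>

// Pack two 4-bit fields (already in [0, 15]) into one byte, the same way the
// GeneratorTensor specializations below do with ((hi << 4) + lo) & 0xff.
inline std::int8_t pack_two_nibbles(int hi, int lo)
{
    return static_cast<std::int8_t>(((hi << 4) + lo) & 0xff);
}

// Storage needed for n logical int4 elements: two per byte, rounded up,
// matching Tensor::GetElementSpaceSize() for pk_i4_t in this commit.
inline std::size_t packed_element_space(std::size_t n) { return (n + 1) / 2; }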
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp

@@ -31,8 +31,8 @@ template <typename SliceLengths,
           typename DstDimAccessOrder,
           index_t SrcVectorDim,
           index_t DstVectorDim,
-          index_t SrcScalarPerVector,
-          index_t DstScalarPerVector,
+          index_t SrcScalarPerVector_,
+          index_t DstScalarPerVector_,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
@@ -55,6 +55,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     static constexpr auto I0 = Number<0>{};

+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
+    static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
+    static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
                                                             const Index& src_slice_origin,
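A standalone compile-time sketch (simplified types, not taken from the header) of what the PackedSize constant above works out to: the public template arguments SrcScalarPerVector_ / DstScalarPerVector_ keep counting logical elements, while the internal Number<> constants count stored units, which halve for pk_i4_t.

// Hypothetical check mirroring the PackedSize lambda (stand-in types).
#include <type_traits>

struct pk_i4_t;            // stand-in for ck::pk_i4_t
using index_t = int;       // stand-in for ck::index_t

template <typename SrcData>
constexpr index_t packed_size = std::is_same_v<SrcData, pk_i4_t> ? 2 : 1;

// e.g. SrcScalarPerVector_ = 8 logical int4 values -> 8 / 2 = 4 stored bytes per vector
static_assert(8 / packed_size<pk_i4_t> == 4, "pk_i4_t packs two values per byte");
static_assert(8 / packed_size<float> == 8, "non-packed types are unaffected");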
@@ -67,6 +77,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
           src_element_op_(src_element_op),
           dst_element_op_(dst_element_op)
     {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+        {
+            static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
+                          "SrcData != DstData");
+
+            static_assert(SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
+                          "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
+
+            static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
+        }
     }

     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -95,11 +116,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

-        static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
+        static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
                       "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");

         constexpr auto src_dim_access_order = SrcDimAccessOrder{};
@@ -180,9 +201,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
             using src_vector_t    = typename src_vector_type::type;

-            auto src_vector_container = src_vector_type{
-                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};
-
             using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
             using dst_vector_t    = typename dst_vector_type::type;

             dst_vector_type op_r_v;
@@ -193,17 +211,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 if constexpr(decltype(src_element_op_)::is_pack8_invocable)
                     return math::min(8, SrcScalarPerVector);
             }
-            if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
+            else if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
             {
                 if constexpr(decltype(src_element_op_)::is_pack4_invocable)
                     return math::min(4, SrcScalarPerVector);
             }
-            if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
+            else if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
             {
                 if constexpr(decltype(src_element_op_)::is_pack2_invocable)
                     return math::min(2, SrcScalarPerVector);
             }
-            return 1;
+            else
+            {
+                return 1;
+            }
         };

         constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
@@ -211,6 +234,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
             using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;

+            auto src_vector_container = src_vector_type{
+                src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
+
             static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
                 // apply the src elementwise op and convert to DstData under the hood if needed
                 src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
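The GetOffset() / PackedSize adjustment above (and the matching one in the store path further down) is the core addressing change: tensor coordinates keep counting logical int4 elements, while the buffer is indexed in stored bytes. A hypothetical one-liner, not taken from the header, of the same mapping:

// Illustration only: logical element offset -> byte offset for pk_i4_t (PackedSize == 2).
constexpr int byte_offset(int logical_offset, int packed_size) { return logical_offset / packed_size; }
static_assert(byte_offset(6, 2) == 3, "logical elements 6 and 7 share stored byte 3");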
@@ -276,10 +302,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
             });
 #else
             // OOB Check
             constexpr auto src_scalar_per_access = generate_sequence(
-                detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+                detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

             constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -350,6 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                        (is_same<f8_t, remove_cvref_t<DstData>>::value &&
                         SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
         {
+            static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
+                          "in-register transpose is not supported for pk_i4_t");
+
             // each transpose does
             // DstScalarPerVector # of src vectors in src_thread_scratch_
             // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
@@ -410,7 +437,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         }
         else
         {
-            static_ford<SliceLengths>{}([&](auto idx) {
+            constexpr auto packed_per_access = generate_sequence(
+                detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
+
+            constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
+
+            static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
                 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
             });
         }
@@ -438,7 +470,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // src scalar per access on each dim
         // TODO: don't use this
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});

         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -526,13 +558,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 // apply DstElementwiseOperation
                 dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
-
-                dst_vector_container.template AsType<DstData>()(i) = dst_v;
             });

             // copy data from dst_vector_container to dst_buf
             dst_buf.template Set<dst_vector_t>(
-                dst_coord_.GetOffset(),
+                dst_coord_.GetOffset() / PackedSize,
                 is_dst_valid,
                 dst_vector_container.template AsType<dst_vector_t>()[I0]);
@@ -586,7 +616,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -644,7 +674,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});

         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -730,7 +760,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ static constexpr auto GetSrcThreadScratchDescriptor()
     {
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -779,7 +809,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
     {
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -790,7 +820,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     {
         // 1st stage of transforms
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});

         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_pk4.hpp (new file, mode 100644)

This diff is collapsed (+905 additions).
include/ck/utility/amd_buffer_addressing.hpp

@@ -429,7 +429,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
                       (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                       (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                       (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-                      (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
+                      (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+                      (is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
                   "wrong! not implemented");

     using r_t = typename vector_type<T, N>::type;
include/ck/utility/data_type.hpp

@@ -13,6 +13,15 @@ using int4_t = _BitInt(4);
 using f8_t   = _BitInt(8);
 using bf8_t  = unsigned _BitInt(8);

+// custom data type - pack int4 data
+struct pk_i4_t
+{
+    using type = int8_t;
+    type data;
+    __host__ __device__ constexpr pk_i4_t() : data{type{}} {}
+    __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
+};
+
 // vector_type
 template <typename T, index_t N>
 struct vector_type;
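pk_i4_t is a thin wrapper around one int8_t that already holds both nibbles; construction just stores the pre-packed byte. A hypothetical host-side sketch (the stand-in type below is not from the header, it only mirrors the definition above):

// Illustration only: build a packed-int4 value from two 4-bit fields.
#include <cstdint>

struct pk_i4_t_sketch   // stand-in for ck::pk_i4_t defined in the hunk above
{
    using type = std::int8_t;
    type data;
    constexpr pk_i4_t_sketch() : data{type{}} {}
    constexpr pk_i4_t_sketch(type init) : data{init} {}
};

// same packing as GeneratorTensor_1<ck::pk_i4_t> with value = 1 (t = 1 + 8 = 9)
constexpr pk_i4_t_sketch packed{static_cast<std::int8_t>(((9 << 4) + 9) & 0xff)};
static_assert(sizeof(pk_i4_t_sketch) == 1, "two int4 values occupy one byte");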
@@ -149,6 +158,13 @@ struct scalar_type<int4_t>
 };
 #endif

+template <>
+struct scalar_type<pk_i4_t>
+{
+    using type                           = pk_i4_t;
+    static constexpr index_t vector_size = 1;
+};
+
 template <>
 struct scalar_type<f8_t>
 {
@@ -990,6 +1006,20 @@ struct vector_type<T, 256>
...
@@ -990,6 +1006,20 @@ struct vector_type<T, 256>
}
}
};
};
template
<
>
struct
nnvb_data_t_selector
<
pk_i4_t
>
{
using
type
=
pk_i4_t
::
type
;
};
template
<
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
pk_i4_t
,
N
>>
{
using
type
=
typename
non_native_vector_base
<
pk_i4_t
,
N
>::
data_t
;
static
constexpr
index_t
vector_size
=
N
;
};
using
int64_t
=
long
;
using
int64_t
=
long
;
// fp64
// fp64
...
@@ -1060,6 +1090,11 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
...
@@ -1060,6 +1090,11 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
// pack int4
using
pk_i4x2_t
=
typename
vector_type
<
pk_i4_t
,
2
>::
type
;
using
pk_i4x4_t
=
typename
vector_type
<
pk_i4_t
,
4
>::
type
;
using
pk_i4x8_t
=
typename
vector_type
<
pk_i4_t
,
8
>::
type
;
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
{
{
...
...
library/include/ck/library/utility/host_tensor.hpp

@@ -324,6 +324,18 @@ struct Tensor
-    std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
+    std::size_t GetElementSpaceSize() const
+    {
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return (mDesc.GetElementSpaceSize() + 1) / 2;
+        }
+        else
+        {
+            return mDesc.GetElementSpaceSize();
+        }
+    }

     std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }

     void SetZero() { ck::ranges::fill<T>(mData, 0); }
@@ -469,29 +481,64 @@ struct Tensor
...
@@ -469,29 +481,64 @@ struct Tensor
template
<
typename
...
Is
>
template
<
typename
...
Is
>
std
::
size_t
GetOffsetFromMultiIndex
(
Is
...
is
)
const
std
::
size_t
GetOffsetFromMultiIndex
(
Is
...
is
)
const
{
{
return
mDesc
.
GetOffsetFromMultiIndex
(
is
...);
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
mDesc
.
GetOffsetFromMultiIndex
(
is
...)
/
2
;
}
else
{
return
mDesc
.
GetOffsetFromMultiIndex
(
is
...);
}
}
}
template
<
typename
...
Is
>
template
<
typename
...
Is
>
T
&
operator
()(
Is
...
is
)
T
&
operator
()(
Is
...
is
)
{
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
}
}
}
template
<
typename
...
Is
>
template
<
typename
...
Is
>
const
T
&
operator
()(
Is
...
is
)
const
const
T
&
operator
()(
Is
...
is
)
const
{
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
}
}
}
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
{
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
}
}
}
const
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
const
const
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
const
{
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
}
}
}
typename
Data
::
iterator
begin
()
{
return
mData
.
begin
();
}
typename
Data
::
iterator
begin
()
{
return
mData
.
begin
();
}
...
...
library/include/ck/library/utility/host_tensor_generator.hpp

@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t>
     }
 };

+template <>
+struct GeneratorTensor_1<ck::pk_i4_t>
+{
+    int8_t value = 1;
+
+    template <typename... Is>
+    ck::pk_i4_t operator()(Is...)
+    {
+        int t         = value + 8;
+        ck::pk_i4_t r = ((t << 4) + t) & 0xff;
+        return r;
+    }
+};
+
 template <typename T>
 struct GeneratorTensor_2
 {
@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t>
     }
 };

+template <>
+struct GeneratorTensor_2<ck::pk_i4_t>
+{
+    int min_value = 0;
+    int max_value = 1;
+
+    template <typename... Is>
+    ck::pk_i4_t operator()(Is...)
+    {
+        int hi        = std::rand() % (max_value - min_value) + min_value + 8;
+        int lo        = std::rand() % (max_value - min_value) + min_value + 8;
+        ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
+        return r;
+    }
+};
+
 #if defined CK_ENABLE_FP8
 template <>
 struct GeneratorTensor_2<ck::f8_t>
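Both generator specializations bias each nibble by +8 before packing, so the stored 4-bit fields stay non-negative, and the two nibbles of every byte are drawn independently. A hypothetical standalone program (not part of the library, chosen range is an assumption) showing the same packing step:

// Illustration only: mimic GeneratorTensor_2<ck::pk_i4_t> with min_value = -2, max_value = 2.
#include <cstdint>
#include <cstdio>
#include <cstdlib>

int main()
{
    const int min_value = -2, max_value = 2;
    const int hi = std::rand() % (max_value - min_value) + min_value + 8; // nibble in [6, 10)
    const int lo = std::rand() % (max_value - min_value) + min_value + 8;
    const auto packed = static_cast<std::int8_t>(((hi << 4) + lo) & 0xff);
    std::printf("hi=%d lo=%d packed=0x%02x\n", hi, lo, static_cast<unsigned char>(packed));
    return 0;
}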