Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3a84f68e
Commit
3a84f68e
authored
Dec 16, 2024
by
Jing Zhang
Browse files
fixed
parent
bf545630
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
33 additions
and
36 deletions
+33
-36
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+2
-3
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+0
-1
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+5
-13
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+1
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+8
-0
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+16
-19
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
View file @
3a84f68e
...
...
@@ -37,6 +37,7 @@ struct DeviceGemmV2 : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
bool
GetPermuteA
()
=
0
;
virtual
bool
GetPermuteB
()
=
0
;
virtual
ck
::
index_t
GetKPerBlock
()
=
0
;
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
3a84f68e
...
...
@@ -410,9 +410,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
// Pre-shuffled Weight
// BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
constexpr
index_t
BK01
=
KPerBlock
/
BK1Value
;
// const index_t BK00 = BK0 / BK01;
const
index_t
BK0_
=
StrideB
/
BK1Value
;
const
index_t
BK00
=
BK0_
/
BK01
;
const
index_t
BK0_
=
StrideB
/
BK1Value
;
const
index_t
BK00
=
BK0_
/
BK01
;
const
auto
b_grid_desc_bk00_n_bk01_bk1_permute
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
BK00
,
N
,
BK01
,
BK1Value
));
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
3a84f68e
...
...
@@ -1137,7 +1137,6 @@ struct ThreadwiseTensorSliceTransfer_v4
}
else
if
constexpr
(
SrcBuffer
::
IsStaticBuffer
())
{
static_assert
(
false
,
""
);
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_ref_to_origin_disp_idx
+
data_to_origin_disp_idx
+
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
3a84f68e
...
...
@@ -82,9 +82,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static_assert
(
is_same_v
<
remove_cvref_t
<
SrcData
>
,
remove_cvref_t
<
DstData
>>
,
"SrcData != DstData"
);
static_assert
(
SrcScalarPerVector_
%
PackedSize
==
0
&&
DstScalarPerVector_
%
PackedSize
==
0
,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1"
);
static_assert
(
SrcScalarPerVector_
%
PackedSize
==
0
&&
DstScalarPerVector_
%
PackedSize
==
0
,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1
for packed data type
"
);
static_assert
(
SrcVectorDim
==
DstVectorDim
,
"pk_i4_t does not support transpose"
);
}
...
...
@@ -234,8 +234,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using
src_elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
using
dst_elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
static_assert
(
elem_op_vec_len
==
1
,
"elem_op_vec_len != 1"
);
auto
src_vector_container
=
src_vector_type
{
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
()
/
PackedSize
,
true
)};
...
...
@@ -300,13 +298,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
TransferDataFromSrcThreadScratchToDstThreadScratch
(
Number
<
ThreadScratchId
>
thread_scratch_id
)
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_assert
(
false
,
""
);
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
});
#else
#if 1
// OOB Check
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector_
>
{},
Number
<
nDim
>
{});
...
...
@@ -369,7 +364,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_thread_scratch_tuple_
(
thread_scratch_id
)
.
template
SetAsType
<
vector_t
>(
src_data_idx_seq
,
op_r_v
);
});
#endif
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
...
...
@@ -381,9 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
(
is_same
<
f8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
4
==
0
&&
DstScalarPerVector
%
4
==
0
)))
{
// static_assert(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
//"transpose is not allowed for pk_i4_t");
#if 1
static_assert
(
is_same_v
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>
,
"transpose is not allowed for pk_i4_t"
);
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
...
...
@@ -441,7 +434,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
transpose_vectors
<
DstData
,
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
src_vector_refs
,
dst_vector_refs
);
});
#endif
}
else
{
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
3a84f68e
...
...
@@ -429,6 +429,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
uint8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
pk_i4_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
...
...
include/ck/utility/data_type.hpp
View file @
3a84f68e
...
...
@@ -1893,6 +1893,14 @@ using bf8x32_t = bf8x32_fnuz_t;
using
bf8x64_t
=
bf8x64_fnuz_t
;
#endif
// u8
using
uint8x2_t
=
typename
vector_type
<
uint8_t
,
2
>::
type
;
using
uint8x4_t
=
typename
vector_type
<
uint8_t
,
4
>::
type
;
using
uint8x8_t
=
typename
vector_type
<
uint8_t
,
8
>::
type
;
using
uint8x16_t
=
typename
vector_type
<
uint8_t
,
16
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
// pack int4
using
pk_i4x2_t
=
typename
vector_type
<
pk_i4_t
,
2
>::
type
;
using
pk_i4x4_t
=
typename
vector_type
<
pk_i4_t
,
4
>::
type
;
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
3a84f68e
...
...
@@ -29,13 +29,6 @@ struct DynamicBuffer
ElementSpaceSize
element_space_size_
;
T
invalid_element_value_
=
T
{
0
};
static
constexpr
index_t
PackedSize
=
[]()
{
if
constexpr
(
is_same_v
<
remove_cvref_t
<
T
>
,
pk_i4_t
>
)
return
2
;
else
return
1
;
}();
__host__
__device__
constexpr
DynamicBuffer
(
T
*
p_data
,
ElementSpaceSize
element_space_size
)
:
p_data_
{
p_data
},
element_space_size_
{
element_space_size
}
{
...
...
@@ -59,7 +52,11 @@ struct DynamicBuffer
__host__
__device__
constexpr
T
&
operator
()(
index_t
i
)
{
return
p_data_
[
i
];
}
template
<
typename
X
>
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
||
!
is_native_type
<
X
>
(),
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
{
// X contains multiple T
...
...
@@ -85,18 +82,14 @@ struct DynamicBuffer
return
amd_buffer_load_invalid_element_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
{
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
,
invalid_element_value_
);
p_data_
,
i
,
is_valid_element
,
element_space_size_
,
invalid_element_value_
);
}
}
else
...
...
@@ -198,10 +191,14 @@ struct DynamicBuffer
dst_buf
.
p_data_
,
dst_offset
,
is_valid_element
,
element_space_size_
/
PackedSize
);
element_space_size_
);
}
template
<
typename
X
>
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
||
!
is_native_type
<
X
>
(),
bool
>::
type
=
false
>
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
// X contains multiple T
...
...
@@ -229,7 +226,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
&&
is_same
<
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
,
int8_t
>::
value
&&
...
...
@@ -381,7 +378,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_add
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
{
...
...
@@ -420,7 +417,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
if
(
is_valid_element
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment