gaoqiong / composable_kernel_ROCM · Commits · 418baed3

Commit 418baed3
Authored by coderfeli on Feb 12, 2025
Parent: b02c0b82

Commit message: moe gemm1 scaleready

Showing 3 changed files with 22 additions and 13 deletions.
Changed files:

    example/65_gemm_multiply_multiply/moe_gemm1.cpp (+10, -8)
    include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp (+4, -3)
    include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp (+8, -2)
example/65_gemm_multiply_multiply/moe_gemm1.cpp
@@ -71,7 +71,7 @@ struct MulABScale
         (void)d2; // for gate, no d2 needed
         (void)d0;
         (void)d1;
-        const float x0_f = c;
+        const float x0_f = c * d1 * d0;
         // const float x0_f = c;
         e = ck::type_convert<EDataType>(x0_f);
     }
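
The functional change in this hunk is that the element op no longer passes the accumulator through untouched; it now multiplies it by the d0 and d1 scale operands before converting to the output type. A minimal standalone sketch of that behavior (hypothetical names; plain static_cast standing in for ck::type_convert):

    // Sketch only: mirrors the post-change MulABScale math, not the CK source itself.
    #include <cstdio>

    template <typename EDataType>
    struct MulABScaleSketch
    {
        // c  : GEMM accumulator
        // d0 : per-token scale
        // d1 : per-expert scale
        // d2 : unused on the gate path ("for gate, no d2 needed")
        template <typename C, typename D0, typename D1, typename D2>
        void operator()(EDataType& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const
        {
            (void)d2;
            const float x0_f = static_cast<float>(c) * static_cast<float>(d1) * static_cast<float>(d0);
            e = static_cast<EDataType>(x0_f); // stand-in for ck::type_convert<EDataType>
        }
    };

    int main()
    {
        MulABScaleSketch<float> op;
        float e = 0.0f;
        op(e, /*c=*/2.0f, /*d0=*/0.5f, /*d1=*/4.0f, /*d2=*/0.0f);
        std::printf("e = %f\n", e); // prints 4.0 = 2.0 * 4.0 * 0.5
        return 0;
    }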
@@ -286,9 +286,9 @@ int main(int argc, char* argv[])
     case 1:
         a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
-        d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
-        d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
-        d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
+        d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 3});
+        d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 3});
+        d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{1, 3});
         break;
     case 2:
         a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
@@ -304,6 +304,9 @@ int main(int argc, char* argv[])
         d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
         d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
     }
+    d0_t_n.savetxt("d0_t_n.txt", "int");
+    d1_e_n.savetxt("d1_e_n.txt", "int");
+    d2_m_n.savetxt("d2_m_n.txt", "int");
     DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
     DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
     DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
@@ -325,8 +328,6 @@ int main(int argc, char* argv[])
     auto b_element_op   = BElementOp{};
     auto cde_element_op = CDEElementOp{};

-    constexpr auto I0 = ck::Number<0>{};
-
     // do GEMM
     auto device_op = DeviceOpInstance{};
@@ -352,7 +353,7 @@ int main(int argc, char* argv[])
         K,
         StrideA,
         StrideB,
-        std::array<ck::index_t, NumDTensor>{I0, I0, I0},
+        StrideDs,
         StrideE,
         KBatch,
         a_element_op,
@@ -406,9 +407,10 @@ int main(int argc, char* argv[])
     {
         const int t = sorted_token_ids(m);
+        const int e = expert_ids(m / sorted_tile_size);
         for(int n = 0; n < N; ++n)
         {
-            cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(m, n), d2_m_n(m, n));
+            cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(e, n), d2_m_n(m, n));
         }
     }
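
The host-reference fix above is an indexing change: row m of the sorted output maps back to token t = sorted_token_ids(m), and its expert is looked up per tile as expert_ids(m / sorted_tile_size); the per-expert scale d1 is now read with that expert index instead of the row index. A small standalone sketch of the same lookup (simplified to per-token / per-expert scalars instead of the example's (t, n) and (e, n) tensors):

    // Sketch only: sorted-row -> (token, expert) lookup as used by the reference loop;
    // container names and sizes are made up for illustration.
    #include <cstdio>
    #include <vector>

    int main()
    {
        const int M = 4;                 // sorted rows
        const int N = 2;                 // output columns
        const int sorted_tile_size = 2;  // rows per expert tile

        std::vector<int> sorted_token_ids = {0, 2, 1, 3}; // row m -> token t
        std::vector<int> expert_ids       = {0, 1};       // tile (m / sorted_tile_size) -> expert e

        std::vector<float> c(M * N, 1.0f);                 // GEMM result C(m, n)
        std::vector<float> d0 = {0.5f, 0.5f, 0.5f, 0.5f};  // per-token scale, indexed by t
        std::vector<float> d1 = {2.0f, 3.0f};              // per-expert scale, indexed by e
        std::vector<float> out(M * N, 0.0f);

        for(int m = 0; m < M; ++m)
        {
            const int t = sorted_token_ids[m];
            const int e = expert_ids[m / sorted_tile_size]; // the newly added lookup
            for(int n = 0; n < N; ++n)
                out[m * N + n] = c[m * N + n] * d1[e] * d0[t]; // same math as MulABScale
        }

        std::printf("out(3,0) = %f\n", out[3 * N]); // 1.0 * 3.0 * 0.5 = 1.5
        return 0;
    }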
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
@@ -1401,7 +1401,7 @@ struct GridwiseMoeGemmGather
         if(i.value == 1)
         {
             ptr_ += expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
-            // if ( threadIdx.x ==0)
+            // if ( threadIdx.x % 16 ==0)
             // printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
         }
         return make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1448,10 +1448,11 @@ struct GridwiseMoeGemmGather
         StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
         StaticallyIndexedArray<float, EMRepeats> scatter_weights;   //= for topk
         // too hack here, 2 specific for topk weights, fixme
-        const float* p_sorted_weights = p_ds_grid[I2];
+        const float* p_sorted_weights = p_ds_grid[I0];
         static_for<0, EMRepeats, 1>{}([&](auto m0) {
             scatter_offsets(m0) = 0;
-            scatter_weights(m0) = p_sorted_weights[c_token_pos + m0];
+            scatter_weights(m0) = p_sorted_weights[(c_token_pos + m0) * problem.StrideDs[0]];
             // if(threadIdx.x % 16 == 0)
             // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
         });

         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
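
The two edits in this hunk change where the per-row top-k weight comes from: it is now read from p_ds_grid[I0], and the row index is multiplied by problem.StrideDs[0], so consecutive sorted rows are no longer assumed to be densely packed. A standalone sketch of that strided read (buffer layout and names are illustrative assumptions, not the kernel's actual D-tensor layout):

    // Sketch only: one weight per sorted row, stored stride_d0 elements apart.
    #include <cstdio>
    #include <vector>

    int main()
    {
        const int num_rows  = 4;
        const int stride_d0 = 3;  // plays the role of problem.StrideDs[0]

        // Only the first element of each "row" carries the top-k weight.
        std::vector<float> d0_buf(num_rows * stride_d0, 0.0f);
        for(int r = 0; r < num_rows; ++r)
            d0_buf[r * stride_d0] = 0.25f * static_cast<float>(r + 1);

        const int c_token_pos   = 1; // first sorted row handled by this thread
        constexpr int EMRepeats = 2; // rows handled per thread
        float scatter_weights[EMRepeats];

        for(int m0 = 0; m0 < EMRepeats; ++m0)
            scatter_weights[m0] = d0_buf[(c_token_pos + m0) * stride_d0]; // strided read, as above

        for(int m0 = 0; m0 < EMRepeats; ++m0)
            std::printf("weight[%d] = %f\n", m0, scatter_weights[m0]); // 0.50, 0.75
        return 0;
    }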
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
@@ -176,10 +176,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
                 src_coords_[i]);
             oob_val = oob_val & is_src_valid;
             if(i.value == ScatterWeightIdx)
             {
-                static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1, "scatter weight dim, should only one vec");
+                static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1, "scatter weight dim, should only one vec");
                 constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
+                // if(threadIdx.x % 8 ==0 )
+                // printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights_(Number<iScatter>{}));
                 static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
                     src_vectors(i).template AsType<float>()(j) = scatter_weights_(Number<iScatter>{});
                 });
             }
@@ -189,11 +191,15 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
                 using DataType = remove_cvref_t<decltype(data_types[i])>;
                 const auto tmp = src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
+                // if(threadIdx.x % 8 ==0 )
+                // printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
                 static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
                     src_vectors(i).template AsType<DataType>()(j) = tmp;
                 });
             }
             else
             {
+                // if(threadIdx.x % 8 ==0 )
+                // printf("bid %d tid %d srcid %d vn\n", blockIdx.y, threadIdx.x, i.value);
                 src_vectors(i).template AsType<src_vector_t>()(I0) =
                     src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
             }