Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
6582aedc
Commit
6582aedc
authored
Mar 31, 2022
by
binmakeswell
Browse files
fix format (#583)
parent
f08fc17f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
575 additions
and
591 deletions
+575
-591
colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
+575
-591
No files found.
colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
View file @
6582aedc
#include <torch/extension.h>
#include "block_reduce.h"
#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include <torch/extension.h>
#include "block_reduce.h"
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_one_fwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
__device__
void
moe_dpch_one_fwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
int
tps
=
threadIdx
.
x
*
pack_size
;
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
T
pack
[
pack_size
];
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockStore
(
ts_store
).
Store
(
dst_row
+
idx
,
pack
);
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
}
BlockStore
(
ts_store
).
Store
(
dst_row
+
idx
,
pack
);
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_one_bwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
__device__
void
moe_dpch_one_bwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
int
tps
=
threadIdx
.
x
*
pack_size
;
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
T
pack
[
pack_size
];
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
pack
);
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockStore
(
ts_store
).
Store
(
src_row
+
idx
,
pack
);
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
pack
);
}
BlockStore
(
ts_store
).
Store
(
src_row
+
idx
,
pack
);
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_two_fwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
)
{
__device__
void
moe_dpch_two_fwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
int
tps
=
threadIdx
.
x
*
pack_size
;
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
T
pack
[
pack_size
];
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockStore
(
ts_store
).
Store
(
dst_row1
+
idx
,
pack
);
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
BlockStore
(
ts_store
).
Store
(
dst_row2
+
idx
,
pack
);
BlockStore
(
ts_store
).
Store
(
dst_row1
+
idx
,
pack
);
}
BlockStore
(
ts_store
).
Store
(
dst_row2
+
idx
,
pack
);
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_two_bwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
)
{
__device__
void
moe_dpch_two_bwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
assert
(
cols
%
pack_size
==
0
);
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
const
int
bpack_size
=
block_size
*
pack_size
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
Block
Store
<
T
,
block_size
,
pack_size
,
typedef
cub
::
Block
Load
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockLoad
;
__shared__
typename
Block
Store
::
TempStorage
ts_
store
;
__shared__
typename
Block
Load
::
TempStorage
ts_
load
;
int
tps
=
threadIdx
.
x
*
pack_size
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
T
pack1
[
pack_size
],
pack2
[
pack_size
];
BlockStore
;
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
BlockLoad
(
ts_load
).
Load
(
dst_row1
+
idx
,
pack1
);
BlockLoad
(
ts_load
).
Load
(
dst_row2
+
idx
,
pack2
);
#pragma unroll
int
tps
=
threadIdx
.
x
*
pack_size
;
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
T
pack1
[
pack_size
],
pack2
[
pack_size
];
pack1
[
i
]
+=
pack2
[
i
];
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
}
BlockLoad
(
ts_load
).
Load
(
dst_row1
+
idx
,
pack1
);
BlockLoad
(
ts_load
).
Load
(
dst_row2
+
idx
,
pack2
);
BlockStore
(
ts_store
).
Store
(
src_row
+
idx
,
pack1
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
pack1
[
i
]
+=
pack2
[
i
];
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
BlockStore
(
ts_store
).
Store
(
src_row
+
idx
,
pack1
);
__device__
void
moe_cb_one_fwd
(
}
T
*
src_row
,
T
*
dst_row
,
}
const
T
weight
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
template
<
typename
T
,
int
block_size
,
int
pack_size
>
const
int
bpack_size
=
block_size
*
pack_size
;
__device__
void
moe_cb_one_fwd
(
T
*
src_row
,
T
*
dst_row
,
const
T
weight
,
const
int
cols
)
{
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
assert
(
cols
%
pack_size
==
0
);
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
const
int
bpack_size
=
block_size
*
pack_size
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
Block
Store
<
T
,
block_size
,
pack_size
,
typedef
cub
::
Block
Load
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockLoad
;
__shared__
typename
Block
Store
::
TempStorage
ts_
store
;
__shared__
typename
Block
Load
::
TempStorage
ts_
load
;
int
tps
=
threadIdx
.
x
*
pack_size
;
T
pack
[
pack_size
];
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockStore
;
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
)
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
#pragma unroll
int
tps
=
threadIdx
.
x
*
pack_size
;
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
T
pack
[
pack_size
]
;
pack
[
i
]
*=
weight
;
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
}
BlockLoad
(
ts_load
).
Load
(
src_row
+
idx
,
pack
);
BlockStore
(
ts_store
).
Store
(
dst_row
+
idx
,
pack
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
pack
[
i
]
*=
weight
;
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
BlockStore
(
ts_store
).
Store
(
dst_row
+
idx
,
pack
);
__device__
void
moe_cb_one_bwd
(
}
T
*
src_row
,
T
*
dst_row
,
T
*
tks_row
,
T
*
weight_grad
,
}
const
T
weight
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
template
<
typename
T
,
int
block_size
,
int
pack_size
>
const
int
bpack_size
=
block_size
*
pack_size
;
__device__
void
moe_cb_one_bwd
(
T
*
src_row
,
T
*
dst_row
,
T
*
tks_row
,
T
*
weight_grad
,
const
T
weight
,
const
int
cols
)
{
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
assert
(
cols
%
pack_size
==
0
);
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
const
int
bpack_size
=
block_size
*
pack_size
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
Block
Store
<
T
,
block_size
,
pack_size
,
typedef
cub
::
Block
Load
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockLoad
;
__shared__
typename
Block
Store
::
TempStorage
ts_
store
;
__shared__
typename
Block
Load
::
TempStorage
ts_
load
;
int
tps
=
threadIdx
.
x
*
pack_size
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
T
grad
[
pack_size
],
tokens
[
pack_size
];
BlockStore
;
float
thread_sum
=
0
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
grad
);
BlockLoad
(
ts_load
).
Load
(
tks_row
+
idx
,
tokens
);
#pragma unroll
int
tps
=
threadIdx
.
x
*
pack_size
;
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
T
grad
[
pack_size
],
tokens
[
pack_size
];
thread_sum
+=
grad
[
i
]
*
tokens
[
i
];
float
thread_sum
=
0
;
grad
[
i
]
*=
weight
;
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
}
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
grad
);
BlockLoad
(
ts_load
).
Load
(
tks_row
+
idx
,
tokens
);
BlockStore
(
ts_store
).
Store
(
src_row
+
idx
,
grad
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
thread_sum
+=
grad
[
i
]
*
tokens
[
i
];
grad
[
i
]
*=
weight
;
}
}
blockReduce
<
ReduceType
::
kSum
,
1
>
(
&
thread_sum
);
BlockStore
(
ts_store
).
Store
(
src_row
+
idx
,
grad
);
}
if
(
threadIdx
.
x
==
0
)
blockReduce
<
ReduceType
::
kSum
,
1
>
(
&
thread_sum
);
*
weight_grad
=
static_cast
<
T
>
(
thread_sum
);
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
if
(
threadIdx
.
x
==
0
)
__device__
void
moe_cb_two_fwd
(
*
weight_grad
=
static_cast
<
T
>
(
thread_sum
);
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
}
const
T
weight1
,
const
T
weight2
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
template
<
typename
T
,
int
block_size
,
int
pack_size
>
const
int
bpack_size
=
block_size
*
pack_size
;
__device__
void
moe_cb_two_fwd
(
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
const
T
weight1
,
const
T
weight2
,
const
int
cols
)
{
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
assert
(
cols
%
pack_size
==
0
);
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
const
int
bpack_size
=
block_size
*
pack_size
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
Block
Store
<
T
,
block_size
,
pack_size
,
typedef
cub
::
Block
Load
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockLoad
;
__shared__
typename
Block
Store
::
TempStorage
ts_
store
;
__shared__
typename
Block
Load
::
TempStorage
ts_
load
;
int
tps
=
threadIdx
.
x
*
pack_size
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
T
pack1
[
pack_size
],
pack2
[
pack_size
];
BlockStore
;
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
BlockLoad
(
ts_load
).
Load
(
src_row1
+
idx
,
pack1
);
BlockLoad
(
ts_load
).
Load
(
src_row2
+
idx
,
pack2
);
#pragma unroll
int
tps
=
threadIdx
.
x
*
pack_size
;
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
T
pack1
[
pack_size
],
pack2
[
pack_size
];
pack1
[
i
]
=
pack1
[
i
]
*
weight1
+
pack2
[
i
]
*
weight2
;
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
}
BlockLoad
(
ts_load
).
Load
(
src_row1
+
idx
,
pack1
);
BlockLoad
(
ts_load
).
Load
(
src_row2
+
idx
,
pack2
);
BlockStore
(
ts_store
).
Store
(
dst_row
+
idx
,
pack1
);
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
pack1
[
i
]
=
pack1
[
i
]
*
weight1
+
pack2
[
i
]
*
weight2
;
}
}
BlockStore
(
ts_store
).
Store
(
dst_row
+
idx
,
pack1
);
}
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_two_bwd
(
__device__
void
moe_cb_two_bwd
(
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
T
*
tks_row1
,
T
*
tks_row2
,
T
*
weight_grad1
,
T
*
tks_row1
,
T
*
tks_row2
,
T
*
weight_grad1
,
T
*
weight_grad2
,
T
*
weight_grad2
,
const
T
weight1
,
const
T
weight1
,
const
T
weight2
,
const
int
cols
)
{
const
T
weight2
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockLoad
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_LOAD_VECTORIZE
>
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
typedef
cub
::
BlockStore
<
T
,
block_size
,
pack_size
,
cub
::
BLOCK_STORE_VECTORIZE
>
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
int
tps
=
threadIdx
.
x
*
pack_size
;
int
tps
=
threadIdx
.
x
*
pack_size
;
T
grad
[
pack_size
],
tokens1
[
pack_size
],
tokens2
[
pack_size
],
T
grad
[
pack_size
],
tokens1
[
pack_size
],
tokens2
[
pack_size
],
sgrad1
[
pack_size
],
sgrad1
[
pack_size
],
sgrad2
[
pack_size
];
sgrad2
[
pack_size
];
float
thread_sum
[
2
]
=
{
0
,
0
};
float
thread_sum
[
2
]
=
{
0
,
0
};
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
for
(
int
idx
=
0
;
idx
+
tps
<
cols
;
idx
+=
bpack_size
)
{
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
grad
);
BlockLoad
(
ts_load
).
Load
(
dst_row
+
idx
,
grad
);
BlockLoad
(
ts_load
).
Load
(
tks_row1
+
idx
,
tokens1
);
BlockLoad
(
ts_load
).
Load
(
tks_row1
+
idx
,
tokens1
);
BlockLoad
(
ts_load
).
Load
(
tks_row2
+
idx
,
tokens2
);
BlockLoad
(
ts_load
).
Load
(
tks_row2
+
idx
,
tokens2
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
pack_size
;
++
i
)
{
thread_sum
[
0
]
+=
grad
[
i
]
*
tokens1
[
i
];
thread_sum
[
0
]
+=
grad
[
i
]
*
tokens1
[
i
];
thread_sum
[
1
]
+=
grad
[
i
]
*
tokens2
[
i
];
thread_sum
[
1
]
+=
grad
[
i
]
*
tokens2
[
i
];
sgrad1
[
i
]
=
weight1
*
grad
[
i
];
sgrad1
[
i
]
=
weight1
*
grad
[
i
];
sgrad2
[
i
]
=
weight2
*
grad
[
i
];
sgrad2
[
i
]
=
weight2
*
grad
[
i
];
}
BlockStore
(
ts_store
).
Store
(
src_row1
+
idx
,
sgrad1
);
BlockStore
(
ts_store
).
Store
(
src_row2
+
idx
,
sgrad2
);
}
}
blockReduce
<
ReduceType
::
kSum
,
2
>
(
thread_sum
);
BlockStore
(
ts_store
).
Store
(
src_row1
+
idx
,
sgrad1
);
BlockStore
(
ts_store
).
Store
(
src_row2
+
idx
,
sgrad2
);
}
if
(
threadIdx
.
x
==
0
)
blockReduce
<
ReduceType
::
kSum
,
2
>
(
thread_sum
);
*
weight_grad1
=
static_cast
<
T
>
(
thread_sum
[
0
]);
else
if
(
threadIdx
.
x
==
1
)
*
weight_grad2
=
static_cast
<
T
>
(
thread_sum
[
1
]);
if
(
threadIdx
.
x
==
0
)
*
weight_grad1
=
static_cast
<
T
>
(
thread_sum
[
0
]);
else
if
(
threadIdx
.
x
==
1
)
*
weight_grad2
=
static_cast
<
T
>
(
thread_sum
[
1
]);
}
}
// DISPATCH KERNELS --------------------------------
// DISPATCH KERNELS --------------------------------
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_fwd_selector
(
__device__
void
moe_dpch_fwd_selector
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
,
const
int
cols
,
const
int
indicator1
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_dpch_two_fwd
<
T
,
block_size
,
pack_size
>
(
moe_dpch_two_fwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
dst_row2
,
src_row
,
dst_row1
,
dst_row2
,
cols
);
cols
);
else
if
(
indicator1
!=
0
)
else
if
(
indicator1
!=
0
)
moe_dpch_one_fwd
<
T
,
block_size
,
pack_size
>
(
moe_dpch_one_fwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
cols
);
src_row
,
dst_row1
,
cols
);
else
if
(
indicator2
!=
0
)
else
if
(
indicator2
!=
0
)
moe_dpch_one_fwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row2
,
cols
);
moe_dpch_one_fwd
<
T
,
block_size
,
pack_size
>
(
else
src_row
,
dst_row2
,
cols
);
return
;
else
return
;
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_bwd_selector
(
__device__
void
moe_dpch_bwd_selector
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
,
const
int
cols
,
const
int
indicator1
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_dpch_two_bwd
<
T
,
block_size
,
pack_size
>
(
moe_dpch_two_bwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
dst_row2
,
src_row
,
dst_row1
,
dst_row2
,
cols
);
cols
);
else
if
(
indicator1
!=
0
)
else
if
(
indicator1
!=
0
)
moe_dpch_one_bwd
<
T
,
block_size
,
pack_size
>
(
moe_dpch_one_bwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
cols
);
src_row
,
dst_row1
,
cols
);
else
if
(
indicator2
!=
0
)
else
if
(
indicator2
!=
0
)
moe_dpch_one_bwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row2
,
cols
);
moe_dpch_one_bwd
<
T
,
block_size
,
pack_size
>
(
else
src_row
,
dst_row2
,
cols
);
return
;
else
return
;
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__global__
void
moe_dpch_fwd_kernel
(
__global__
void
moe_dpch_fwd_kernel
(
T
*
batch_tokens
,
T
*
expert_input
,
T
*
batch_tokens
,
T
*
expert_input
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
mask1
,
int
*
mask2
,
int
*
dest2
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
h
)
{
int
row
=
blockIdx
.
x
;
int
row
=
blockIdx
.
x
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
moe_dpch_fwd_selector
<
T
,
block_size
,
pack_size
>
(
moe_dpch_fwd_selector
<
T
,
block_size
,
pack_size
>
(
batch_tokens
+
(
row
*
h
),
expert_input
+
(
dest1
[
row
]
*
h
),
batch_tokens
+
(
row
*
h
),
expert_input
+
(
dest2
[
row
]
*
h
),
h
,
mask1
[
row
],
indicator2
);
expert_input
+
(
dest1
[
row
]
*
h
),
expert_input
+
(
dest2
[
row
]
*
h
),
h
,
mask1
[
row
],
indicator2
);
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__global__
void
moe_dpch_bwd_kernel
(
__global__
void
moe_dpch_bwd_kernel
(
T
*
tokens_grad
,
T
*
expert_grad
,
int
*
mask1
,
T
*
tokens_grad
,
T
*
expert_grad
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
int
*
mask1
,
int
*
mask2
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
h
)
{
int
row
=
blockIdx
.
x
;
int
row
=
blockIdx
.
x
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
moe_dpch_bwd_selector
<
T
,
block_size
,
pack_size
>
(
moe_dpch_bwd_selector
<
T
,
block_size
,
pack_size
>
(
tokens_grad
+
(
row
*
h
),
expert_grad
+
(
dest1
[
row
]
*
h
),
tokens_grad
+
(
row
*
h
),
expert_grad
+
(
dest2
[
row
]
*
h
),
h
,
mask1
[
row
],
indicator2
);
expert_grad
+
(
dest1
[
row
]
*
h
),
expert_grad
+
(
dest2
[
row
]
*
h
),
h
,
mask1
[
row
],
indicator2
);
}
}
// COMBINE KERNELS --------------------------------
// COMBINE KERNELS --------------------------------
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_fwd_selector
(
__device__
void
moe_cb_fwd_selector
(
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
const
int
cols
,
const
int
cols
,
const
T
weight1
,
const
T
weight1
,
const
T
weight2
,
const
T
weight2
,
const
int
indicator1
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_cb_two_fwd
<
T
,
block_size
,
pack_size
>
(
moe_cb_two_fwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
src_row2
,
dst_row
,
src_row1
,
src_row2
,
dst_row
,
weight1
,
weight2
,
cols
);
weight1
,
weight2
,
cols
);
else
if
(
indicator1
!=
0
)
else
if
(
indicator1
!=
0
)
moe_cb_one_fwd
<
T
,
block_size
,
pack_size
>
(
moe_cb_one_fwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
dst_row
,
weight1
,
cols
);
src_row1
,
dst_row
,
weight1
,
cols
);
else
if
(
indicator2
!=
0
)
else
if
(
indicator2
!=
0
)
moe_cb_one_fwd
<
T
,
block_size
,
pack_size
>
(
src_row2
,
dst_row
,
weight2
,
cols
);
moe_cb_one_fwd
<
T
,
block_size
,
pack_size
>
(
else
src_row2
,
dst_row
,
weight2
,
cols
);
return
;
else
return
;
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_bwd_selector
(
__device__
void
moe_cb_bwd_selector
(
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
const
int
cols
,
const
int
cols
,
T
*
tks_row1
,
T
*
tks_row2
,
T
*
tks_row1
,
T
*
tks_row2
,
T
*
wt_grad1
,
T
*
wt_grad2
,
T
*
wt_grad1
,
T
*
wt_grad2
,
const
T
weight1
,
const
T
weight1
,
const
T
weight2
,
const
T
weight2
,
const
int
indicator1
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_cb_two_bwd
<
T
,
block_size
,
pack_size
>
(
moe_cb_two_bwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
src_row2
,
dst_row
,
src_row1
,
src_row2
,
dst_row
,
tks_row1
,
tks_row2
,
wt_grad1
,
tks_row1
,
tks_row2
,
wt_grad1
,
wt_grad2
,
wt_grad2
,
weight1
,
weight2
,
cols
);
weight1
,
weight2
,
cols
);
else
if
(
indicator1
!=
0
)
else
if
(
indicator1
!=
0
)
moe_cb_one_bwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
dst_row
,
tks_row1
,
moe_cb_one_bwd
<
T
,
block_size
,
pack_size
>
(
wt_grad1
,
weight1
,
cols
);
src_row1
,
dst_row
,
tks_row1
,
wt_grad1
,
weight1
,
cols
);
else
if
(
indicator2
!=
0
)
else
if
(
indicator2
!=
0
)
moe_cb_one_bwd
<
T
,
block_size
,
pack_size
>
(
src_row2
,
dst_row
,
tks_row2
,
moe_cb_one_bwd
<
T
,
block_size
,
pack_size
>
(
wt_grad2
,
weight2
,
cols
);
src_row2
,
dst_row
,
tks_row2
,
wt_grad2
,
weight2
,
cols
);
else
else
return
;
return
;
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__global__
void
moe_cb_fwd_kernel
(
T
*
expert_tokens
,
T
*
combine_tokens
,
__global__
void
moe_cb_fwd_kernel
(
T
*
logits
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
T
*
expert_tokens
,
T
*
combine_tokens
,
T
*
logits
,
int
*
dest2
,
const
int
e
,
const
int
c
,
int
*
mask1
,
int
*
mask2
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
e
,
const
int
c
,
const
int
h
)
{
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
T
*
row_log
=
logits
+
(
row
*
e
);
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
moe_cb_fwd_selector
<
T
,
block_size
,
pack_size
>
(
T
*
row_log
=
logits
+
(
row
*
e
);
expert_tokens
+
(
dest1
[
row
]
*
h
),
expert_tokens
+
(
dest2
[
row
]
*
h
),
moe_cb_fwd_selector
<
T
,
block_size
,
pack_size
>
(
combine_tokens
+
(
row
*
h
),
h
,
row_log
[
eid1
],
row_log
[
eid2
],
mask1
[
row
],
expert_tokens
+
(
dest1
[
row
]
*
h
),
expert_tokens
+
(
dest2
[
row
]
*
h
),
indicator2
);
combine_tokens
+
(
row
*
h
),
h
,
row_log
[
eid1
],
row_log
[
eid2
],
mask1
[
row
],
indicator2
);
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__global__
void
moe_cb_bwd_kernel
(
__global__
void
moe_cb_bwd_kernel
(
T
*
tokens_grad
,
T
*
expert_grad
,
T
*
tks
,
T
*
tokens_grad
,
T
*
expert_grad
,
T
*
tks
,
T
*
logits
,
T
*
logits_grad
,
int
*
mask1
,
T
*
logits
,
T
*
logits_grad
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
int
*
mask1
,
int
*
mask2
,
const
int
e
,
const
int
c
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
e
,
const
int
c
,
const
int
h
)
{
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
T
*
row_log
=
logits
+
(
row
*
e
),
*
row_grad
=
logits_grad
+
(
row
*
e
);
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
moe_cb_bwd_selector
<
T
,
block_size
,
pack_size
>
(
T
*
row_log
=
logits
+
(
row
*
e
),
*
row_grad
=
logits_grad
+
(
row
*
e
);
expert_grad
+
(
dest1
[
row
]
*
h
),
expert_grad
+
(
dest2
[
row
]
*
h
),
moe_cb_bwd_selector
<
T
,
block_size
,
pack_size
>
(
tokens_grad
+
(
row
*
h
),
h
,
tks
+
(
dest1
[
row
]
*
h
),
expert_grad
+
(
dest1
[
row
]
*
h
),
expert_grad
+
(
dest2
[
row
]
*
h
),
tks
+
(
dest2
[
row
]
*
h
),
row_grad
+
eid1
,
row_grad
+
eid2
,
row_log
[
eid1
],
tokens_grad
+
(
row
*
h
),
h
,
row_log
[
eid2
],
mask1
[
row
],
indicator2
);
tks
+
(
dest1
[
row
]
*
h
),
tks
+
(
dest2
[
row
]
*
h
),
row_grad
+
eid1
,
row_grad
+
eid2
,
row_log
[
eid1
],
row_log
[
eid2
],
mask1
[
row
],
indicator2
);
}
}
//CUMSUM KERNEL --------------------------------
// CUMSUM KERNEL --------------------------------
template
<
int
block_size
,
int
pack_size
>
template
<
int
block_size
,
int
pack_size
>
__global__
void
cumsum_kernel
(
__global__
void
cumsum_kernel
(
int
*
inputs
,
int
*
outputs
,
const
int
s
,
int
*
inputs
,
int
*
outputs
,
const
int
e
)
{
const
int
s
,
const
int
e
)
{
assert
(
s
%
pack_size
==
0
);
assert
(
s
%
pack_size
==
0
);
constexpr
int
bpack_size
=
block_size
*
pack_size
;
constexpr
int
bpack_size
=
block_size
*
pack_size
;
int
tid
=
threadIdx
.
x
,
bid
=
blockIdx
.
x
,
tps
=
tid
*
pack_size
,
last_sum
=
-
1
;
int
tid
=
threadIdx
.
x
,
bid
=
blockIdx
.
x
,
__shared__
int
temp
[
block_size
+
1
];
tps
=
tid
*
pack_size
,
last_sum
=
-
1
;
int
pack
[
pack_size
];
__shared__
int
temp
[
block_size
+
1
];
int
pack
[
pack_size
];
for
(
int
idx
=
0
;
idx
<
s
;
idx
+=
bpack_size
)
{
for
(
int
idx
=
0
;
idx
<
s
;
idx
+=
bpack_size
)
{
int
offset
=
1
;
int
offset
=
1
;
if
(
idx
+
tps
<
s
)
{
if
(
idx
+
tps
<
s
)
{
temp
[
tid
]
=
inputs
[
tps
*
e
+
bid
];
temp
[
tid
]
=
inputs
[
tps
*
e
+
bid
];
#pragma unroll
#pragma unroll
for
(
int
i
=
1
;
i
<
pack_size
;
++
i
)
{
for
(
int
i
=
1
;
i
<
pack_size
;
++
i
)
{
pack
[
i
]
=
inputs
[(
tps
+
i
)
*
e
+
bid
];
pack
[
i
]
=
inputs
[(
tps
+
i
)
*
e
+
bid
];
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
1
;
i
<
pack_size
;
++
i
)
{
for
(
int
i
=
1
;
i
<
pack_size
;
++
i
)
{
temp
[
tid
]
+=
pack
[
i
];
temp
[
tid
]
+=
pack
[
i
];
}
}
}
}
for
(
int
i
=
block_size
>>
1
;
i
>
0
;
i
>>=
1
)
{
for
(
int
i
=
block_size
>>
1
;
i
>
0
;
i
>>=
1
)
{
__syncthreads
();
__syncthreads
();
if
(
tid
<
i
)
{
if
(
tid
<
i
)
{
int
j
=
offset
*
(
2
*
tid
+
1
)
-
1
;
int
j
=
offset
*
(
2
*
tid
+
1
)
-
1
;
temp
[
j
+
offset
]
+=
temp
[
j
];
temp
[
j
+
offset
]
+=
temp
[
j
];
}
}
offset
<<=
1
;
offset
<<=
1
;
}
if
(
tid
==
0
)
{
temp
[
block_size
]
=
temp
[
block_size
-
1
];
temp
[
block_size
-
1
]
=
0
;
}
for
(
int
i
=
1
;
i
<
block_size
;
i
<<=
1
)
{
offset
>>=
1
;
__syncthreads
();
if
(
tid
<
i
)
{
int
j
=
offset
*
(
2
*
tid
+
1
)
-
1
,
k
=
j
+
offset
,
ts
=
temp
[
j
];
temp
[
j
]
=
temp
[
k
];
temp
[
k
]
+=
ts
;
}
}
__syncthreads
();
if
(
tid
==
0
)
temp
[
0
]
=
temp
[
block_size
];
__syncthreads
();
if
(
idx
+
tps
<
s
)
{
temp
[
tid
+
1
]
+=
last_sum
;
#pragma unroll
for
(
int
i
=
pack_size
-
1
;
i
>
0
;
--
i
)
{
outputs
[(
tps
+
i
)
*
e
+
bid
]
=
temp
[
tid
+
1
];
temp
[
tid
+
1
]
-=
pack
[
i
];
}
outputs
[
tps
*
e
+
bid
]
=
temp
[
tid
+
1
];
}
__syncthreads
();
last_sum
+=
temp
[
0
];
inputs
+=
bpack_size
*
e
;
outputs
+=
bpack_size
*
e
;
}
}
if
(
tid
==
0
)
{
temp
[
block_size
]
=
temp
[
block_size
-
1
];
temp
[
block_size
-
1
]
=
0
;
}
for
(
int
i
=
1
;
i
<
block_size
;
i
<<=
1
)
{
offset
>>=
1
;
__syncthreads
();
if
(
tid
<
i
)
{
int
j
=
offset
*
(
2
*
tid
+
1
)
-
1
,
k
=
j
+
offset
,
ts
=
temp
[
j
];
temp
[
j
]
=
temp
[
k
];
temp
[
k
]
+=
ts
;
}
}
__syncthreads
();
if
(
tid
==
0
)
temp
[
0
]
=
temp
[
block_size
];
__syncthreads
();
if
(
idx
+
tps
<
s
)
{
temp
[
tid
+
1
]
+=
last_sum
;
#pragma unroll
for
(
int
i
=
pack_size
-
1
;
i
>
0
;
--
i
)
{
outputs
[(
tps
+
i
)
*
e
+
bid
]
=
temp
[
tid
+
1
];
temp
[
tid
+
1
]
-=
pack
[
i
];
}
outputs
[
tps
*
e
+
bid
]
=
temp
[
tid
+
1
];
}
__syncthreads
();
last_sum
+=
temp
[
0
];
inputs
+=
bpack_size
*
e
;
outputs
+=
bpack_size
*
e
;
}
}
}
//LAUNCH FUNCTIONS --------------------------------
// LAUNCH FUNCTIONS --------------------------------
template
<
typename
T
>
template
<
typename
T
>
void
moe_dpch_fwd_launch
(
void
moe_dpch_fwd_launch
(
T
*
batch_tokens
,
T
*
expert_input
,
int
*
mask1
,
T
*
batch_tokens
,
T
*
expert_input
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
const
int
s
,
int
*
mask1
,
int
*
mask2
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
s
,
const
int
h
)
{
if
(
h
<
256
)
moe_dpch_fwd_kernel
<
T
,
32
,
4
>
if
(
h
<
256
)
<<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_fwd_kernel
<
T
,
32
,
4
><<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
512
)
else
if
(
h
<
512
)
moe_dpch_fwd_kernel
<
T
,
32
,
8
>
moe_dpch_fwd_kernel
<
T
,
32
,
8
><<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
<<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
1024
)
else
if
(
h
<
1024
)
moe_dpch_fwd_kernel
<
T
,
32
,
16
><<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_fwd_kernel
<
T
,
32
,
16
>
else
if
(
h
<
2048
)
<<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_fwd_kernel
<
T
,
64
,
16
><<<
s
,
64
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
2048
)
else
moe_dpch_fwd_kernel
<
T
,
64
,
16
>
moe_dpch_fwd_kernel
<
T
,
128
,
16
><<<
s
,
128
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
<<<
s
,
64
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
moe_dpch_fwd_kernel
<
T
,
128
,
16
>
<<<
s
,
128
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
}
}
template
<
typename
T
>
template
<
typename
T
>
void
moe_dpch_bwd_launch
(
void
moe_dpch_bwd_launch
(
T
*
tokens_grad
,
T
*
expert_grad
,
int
*
mask1
,
int
*
mask2
,
T
*
tokens_grad
,
T
*
expert_grad
,
int
*
dest1
,
int
*
dest2
,
const
int
s
,
const
int
h
)
{
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
if
(
h
<
256
)
const
int
s
,
const
int
h
)
{
moe_dpch_bwd_kernel
<
T
,
32
,
4
>
<<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
if
(
h
<
256
)
else
if
(
h
<
512
)
moe_dpch_bwd_kernel
<
T
,
32
,
4
><<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_bwd_kernel
<
T
,
32
,
8
>
else
if
(
h
<
512
)
<<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_bwd_kernel
<
T
,
32
,
8
><<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
1024
)
else
if
(
h
<
1024
)
moe_dpch_bwd_kernel
<
T
,
32
,
16
>
moe_dpch_bwd_kernel
<
T
,
32
,
16
><<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
<<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
if
(
h
<
2048
)
else
if
(
h
<
2048
)
moe_dpch_bwd_kernel
<
T
,
64
,
16
><<<
s
,
64
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_bwd_kernel
<
T
,
64
,
16
>
else
<<<
s
,
64
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
moe_dpch_bwd_kernel
<
T
,
128
,
16
><<<
s
,
128
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
else
moe_dpch_bwd_kernel
<
T
,
128
,
16
>
<<<
s
,
128
>>>
(
tokens_grad
,
expert_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
}
}
template
<
typename
T
>
template
<
typename
T
>
void
moe_cb_fwd_launch
(
void
moe_cb_fwd_launch
(
T
*
expert_tokens
,
T
*
combine_tokens
,
T
*
logits
,
T
*
expert_tokens
,
T
*
combine_tokens
,
T
*
logits
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
int
*
mask1
,
int
*
mask2
,
const
int
s
,
const
int
e
,
const
int
c
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
s
,
const
int
e
,
const
int
c
,
const
int
h
)
{
if
(
h
<
256
)
moe_cb_fwd_kernel
<
T
,
32
,
4
><<<
s
,
32
>>>
(
expert_tokens
,
combine_tokens
,
if
(
h
<
256
)
logits
,
mask1
,
mask2
,
dest1
,
dest2
,
moe_cb_fwd_kernel
<
T
,
32
,
4
><<<
s
,
32
>>>
e
,
c
,
h
);
(
expert_tokens
,
combine_tokens
,
logits
,
mask1
,
mask2
,
dest1
,
dest2
,
e
,
c
,
h
);
else
if
(
h
<
512
)
else
if
(
h
<
512
)
moe_cb_fwd_kernel
<
T
,
32
,
8
><<<
s
,
32
>>>
(
expert_tokens
,
combine_tokens
,
moe_cb_fwd_kernel
<
T
,
32
,
8
><<<
s
,
32
>>>
logits
,
mask1
,
mask2
,
dest1
,
dest2
,
(
expert_tokens
,
combine_tokens
,
logits
,
mask1
,
mask2
,
dest1
,
dest2
,
e
,
c
,
h
);
e
,
c
,
h
);
else
if
(
h
<
1024
)
else
if
(
h
<
1024
)
moe_cb_fwd_kernel
<
T
,
32
,
16
><<<
s
,
32
>>>
moe_cb_fwd_kernel
<
T
,
32
,
16
><<<
s
,
32
>>>
(
expert_tokens
,
combine_tokens
,
(
expert_tokens
,
combine_tokens
,
logits
,
mask1
,
mask2
,
dest1
,
dest2
,
e
,
c
,
h
);
logits
,
mask1
,
mask2
,
dest1
,
dest2
,
else
if
(
h
<
2048
)
e
,
c
,
h
);
moe_cb_fwd_kernel
<
T
,
64
,
16
><<<
s
,
64
>>>
else
if
(
h
<
2048
)
(
expert_tokens
,
combine_tokens
,
logits
,
mask1
,
mask2
,
dest1
,
dest2
,
e
,
c
,
h
);
moe_cb_fwd_kernel
<
T
,
64
,
16
><<<
s
,
64
>>>
(
expert_tokens
,
combine_tokens
,
else
logits
,
mask1
,
mask2
,
dest1
,
dest2
,
moe_cb_fwd_kernel
<
T
,
128
,
16
><<<
s
,
128
>>>
e
,
c
,
h
);
(
expert_tokens
,
combine_tokens
,
logits
,
mask1
,
mask2
,
dest1
,
dest2
,
e
,
c
,
h
);
else
moe_cb_fwd_kernel
<
T
,
128
,
16
><<<
s
,
128
>>>
(
expert_tokens
,
combine_tokens
,
logits
,
mask1
,
mask2
,
dest1
,
dest2
,
e
,
c
,
h
);
}
}
template
<
typename
T
>
template
<
typename
T
>
void
moe_cb_bwd_launch
(
void
moe_cb_bwd_launch
(
T
*
tokens_grad
,
T
*
expert_grad
,
T
*
tks
,
T
*
logits
,
T
*
tokens_grad
,
T
*
expert_grad
,
T
*
tks
,
T
*
logits_grad
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
T
*
logits
,
T
*
logits_grad
,
int
*
dest2
,
const
int
s
,
const
int
e
,
const
int
c
,
int
*
mask1
,
int
*
mask2
,
const
int
h
)
{
int
*
dest1
,
int
*
dest2
,
const
int
s
,
const
int
e
,
const
int
c
,
const
int
h
)
{
if
(
h
<
256
)
moe_cb_bwd_kernel
<
T
,
32
,
4
><<<
s
,
32
>>>
(
tokens_grad
,
expert_grad
,
tks
,
if
(
h
<
256
)
logits
,
logits_grad
,
mask1
,
mask2
,
moe_cb_bwd_kernel
<
T
,
32
,
4
><<<
s
,
32
>>>
dest1
,
dest2
,
e
,
c
,
h
);
(
tokens_grad
,
expert_grad
,
tks
,
logits
,
logits_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
e
,
c
,
h
);
else
// if (h < 512)
else
// if (h < 512)
moe_cb_bwd_kernel
<
T
,
64
,
4
><<<
s
,
64
>>>
(
tokens_grad
,
expert_grad
,
tks
,
moe_cb_bwd_kernel
<
T
,
64
,
4
><<<
s
,
64
>>>
logits
,
logits_grad
,
mask1
,
mask2
,
(
tokens_grad
,
expert_grad
,
tks
,
logits
,
logits_grad
,
mask1
,
mask2
,
dest1
,
dest2
,
e
,
c
,
h
);
dest1
,
dest2
,
e
,
c
,
h
);
// else if (h < 1024)
// else if (h < 1024)
// moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
// moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
// else
// dest1, dest2, e, c, h);
// moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
// else
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
// moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
// dest1, dest2, e, c, h);
}
}
void
cumsum_launch
(
void
cumsum_launch
(
int
*
inputs
,
int
*
outputs
,
const
int
s
,
const
int
e
)
{
int
*
inputs
,
int
*
outputs
,
const
int
s
,
const
int
e
)
{
if
(
s
<=
256
)
cumsum_kernel
<
256
,
1
><<<
e
,
256
>>>
(
inputs
,
outputs
,
s
,
e
);
if
(
s
<=
256
)
else
if
(
s
<=
512
)
cumsum_kernel
<
256
,
1
><<<
e
,
256
>>>
(
inputs
,
outputs
,
s
,
e
);
cumsum_kernel
<
512
,
1
><<<
e
,
512
>>>
(
inputs
,
outputs
,
s
,
e
);
else
if
(
s
<=
512
)
else
if
(
s
<=
1024
)
cumsum_kernel
<
512
,
1
><<<
e
,
512
>>>
(
inputs
,
outputs
,
s
,
e
);
cumsum_kernel
<
1024
,
1
><<<
e
,
1024
>>>
(
inputs
,
outputs
,
s
,
e
);
else
if
(
s
<=
1024
)
else
if
(
s
<=
2048
)
cumsum_kernel
<
1024
,
1
><<<
e
,
1024
>>>
(
inputs
,
outputs
,
s
,
e
);
cumsum_kernel
<
1024
,
2
><<<
e
,
1024
>>>
(
inputs
,
outputs
,
s
,
e
);
else
if
(
s
<=
2048
)
else
cumsum_kernel
<
1024
,
2
><<<
e
,
1024
>>>
(
inputs
,
outputs
,
s
,
e
);
cumsum_kernel
<
1024
,
4
><<<
e
,
1024
>>>
(
inputs
,
outputs
,
s
,
e
);
else
cumsum_kernel
<
1024
,
4
><<<
e
,
1024
>>>
(
inputs
,
outputs
,
s
,
e
);
}
}
// API FUNCTIONS --------------------------------
// API FUNCTIONS --------------------------------
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
switch
(
TYPE
)
\
switch
(
TYPE
)
{
\
{
\
case
at
::
ScalarType
::
Float
:
{
\
case
at
::
ScalarType
::
Float
:
\
using
scalar_t
=
float
;
\
{
\
__VA_ARGS__
;
\
using
scalar_t
=
float
;
\
break
;
\
__VA_ARGS__
;
\
}
\
break
;
\
case
at
::
ScalarType
::
Half
:
{
\
}
\
using
scalar_t
=
at
::
Half
;
\
case
at
::
ScalarType
::
Half
:
\
__VA_ARGS__
;
\
{
\
break
;
\
using
scalar_t
=
at
::
Half
;
\
}
\
__VA_ARGS__
;
\
default:
\
break
;
\
AT_ERROR
(
#
NAME
,
" not implemented yet for specific data type."
);
\
}
\
}
default:
\
AT_ERROR
(
#
NAME
,
" not implemented yet for specific data type."
);
\
torch
::
Tensor
moe_dispatch_cuda_forward
(
int
s
,
int
ec
,
int
h
,
}
torch
::
Tensor
batch_tokens
,
torch
::
Tensor
mask
,
torch
::
Tensor
moe_dispatch_cuda_forward
(
torch
::
Tensor
dest_idx
)
{
int
s
,
int
ec
,
int
h
,
torch
::
Tensor
batch_tokens
,
assert
(
h
%
16
==
0
);
torch
::
Tensor
mask
,
auto
res
=
torch
::
zeros
(
torch
::
Tensor
dest_idx
)
{
{
ec
,
h
},
torch
::
dtype
(
batch_tokens
.
dtype
()).
device
(
batch_tokens
.
device
()));
assert
(
h
%
16
==
0
);
auto
k
=
mask
.
size
(
0
);
auto
res
=
torch
::
zeros
({
ec
,
h
},
torch
::
dtype
(
batch_tokens
.
dtype
()).
device
(
batch_tokens
.
device
()));
DISPATCH_FLOAT_AND_HALF
(
auto
k
=
mask
.
size
(
0
);
batch_tokens
.
scalar_type
(),
"moe dispatch forward"
,
moe_dpch_fwd_launch
<
scalar_t
>
(
DISPATCH_FLOAT_AND_HALF
(
batch_tokens
.
data
<
scalar_t
>
(),
res
.
data
<
scalar_t
>
(),
batch_tokens
.
scalar_type
(),
"moe dispatch forward"
,
mask
[
0
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
moe_dpch_fwd_launch
<
scalar_t
>
(
dest_idx
[
0
].
data
<
int
>
(),
batch_tokens
.
data
<
scalar_t
>
(),
res
.
data
<
scalar_t
>
(),
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
s
,
h
));
mask
[
0
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
return
res
;
s
,
h
)
);
return
res
;
}
}
torch
::
Tensor
moe_dispatch_cuda_backward
(
torch
::
Tensor
moe_dispatch_cuda_backward
(
int
s
,
int
ec
,
int
h
,
int
s
,
int
ec
,
int
h
,
torch
::
Tensor
expert_grad
,
torch
::
Tensor
expert_grad
,
torch
::
Tensor
mask
,
torch
::
Tensor
mask
,
torch
::
Tensor
dest_idx
)
{
torch
::
Tensor
dest_idx
)
{
assert
(
h
%
16
==
0
);
assert
(
h
%
16
==
0
);
auto
res
=
torch
::
zeros
(
auto
res
=
torch
::
zeros
({
s
,
h
},
{
s
,
h
},
torch
::
dtype
(
expert_grad
.
dtype
()).
device
(
expert_grad
.
device
()));
torch
::
dtype
(
expert_grad
.
dtype
()).
device
(
expert_grad
.
device
()));
auto
k
=
mask
.
size
(
0
);
auto
k
=
mask
.
size
(
0
);
DISPATCH_FLOAT_AND_HALF
(
DISPATCH_FLOAT_AND_HALF
(
expert_grad
.
scalar_type
(),
"moe dispatch backward"
,
expert_grad
.
scalar_type
(),
"moe dispatch backward"
,
moe_dpch_bwd_launch
<
scalar_t
>
(
moe_dpch_bwd_launch
<
scalar_t
>
(
res
.
data
<
scalar_t
>
(),
expert_grad
.
data
<
scalar_t
>
(),
res
.
data
<
scalar_t
>
(),
expert_grad
.
data
<
scalar_t
>
(),
mask
[
0
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
mask
[
0
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
s
,
h
));
s
,
h
)
);
return
res
;
return
res
;
}
}
torch
::
Tensor
moe_combine_cuda_forward
(
torch
::
Tensor
moe_combine_cuda_forward
(
int
s
,
int
e
,
int
c
,
int
h
,
int
s
,
int
e
,
int
c
,
int
h
,
torch
::
Tensor
expert_tokens
,
torch
::
Tensor
expert_tokens
,
torch
::
Tensor
logits
,
torch
::
Tensor
mask
,
torch
::
Tensor
logits
,
torch
::
Tensor
dest_idx
)
{
torch
::
Tensor
mask
,
torch
::
Tensor
dest_idx
)
{
assert
(
h
%
16
==
0
);
assert
(
expert_tokens
.
dtype
()
==
logits
.
dtype
());
assert
(
h
%
16
==
0
);
assert
(
expert_tokens
.
dtype
()
==
logits
.
dtype
());
auto
res
=
torch
::
zeros
(
{
s
,
h
},
auto
res
=
torch
::
zeros
({
s
,
h
},
torch
::
dtype
(
expert_tokens
.
dtype
()).
device
(
expert_tokens
.
device
()));
torch
::
dtype
(
expert_tokens
.
dtype
()).
device
(
expert_tokens
.
device
()));
auto
k
=
mask
.
size
(
0
);
auto
k
=
mask
.
size
(
0
);
DISPATCH_FLOAT_AND_HALF
(
DISPATCH_FLOAT_AND_HALF
(
expert_tokens
.
scalar_type
(),
"moe combine forward"
,
expert_tokens
.
scalar_type
(),
"moe combine forward"
,
moe_cb_fwd_launch
<
scalar_t
>
(
moe_cb_fwd_launch
<
scalar_t
>
(
expert_tokens
.
data
<
scalar_t
>
(),
res
.
data
<
scalar_t
>
(),
expert_tokens
.
data
<
scalar_t
>
(),
res
.
data
<
scalar_t
>
(),
logits
.
data
<
scalar_t
>
(),
logits
.
data
<
scalar_t
>
(),
mask
[
0
].
data
<
int
>
(),
mask
[
0
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
s
,
e
,
c
,
s
,
e
,
c
,
h
)
h
));
);
return
res
;
return
res
;
}
}
std
::
vector
<
torch
::
Tensor
>
moe_combine_cuda_backward
(
std
::
vector
<
torch
::
Tensor
>
int
s
,
int
e
,
int
c
,
int
h
,
moe_combine_cuda_backward
(
int
s
,
int
e
,
int
c
,
int
h
,
torch
::
Tensor
tokens_grad
,
torch
::
Tensor
tokens_grad
,
torch
::
Tensor
expert_tokens
,
torch
::
Tensor
logits
,
torch
::
Tensor
expert_tokens
,
torch
::
Tensor
mask
,
torch
::
Tensor
dest_idx
)
{
torch
::
Tensor
logits
,
torch
::
Tensor
mask
,
assert
(
h
%
16
==
0
);
torch
::
Tensor
dest_idx
)
{
assert
(
tokens_grad
.
dtype
()
==
expert_tokens
.
dtype
());
assert
(
expert_tokens
.
dtype
()
==
logits
.
dtype
());
assert
(
h
%
16
==
0
);
assert
(
tokens_grad
.
dtype
()
==
expert_tokens
.
dtype
());
auto
egrad
=
torch
::
zeros
(
assert
(
expert_tokens
.
dtype
()
==
logits
.
dtype
());
{
e
*
c
,
h
},
torch
::
dtype
(
tokens_grad
.
dtype
()).
device
(
tokens_grad
.
device
())),
auto
egrad
=
torch
::
zeros
({
e
*
c
,
h
},
wgrad
=
torch
::
zeros
(
torch
::
dtype
(
tokens_grad
.
dtype
()).
device
(
tokens_grad
.
device
())),
{
s
,
e
},
torch
::
dtype
(
logits
.
dtype
()).
device
(
logits
.
device
()));
wgrad
=
torch
::
zeros
({
s
,
e
},
torch
::
dtype
(
logits
.
dtype
()).
device
(
logits
.
device
()));
auto
k
=
mask
.
size
(
0
);
auto
k
=
mask
.
size
(
0
);
DISPATCH_FLOAT_AND_HALF
(
DISPATCH_FLOAT_AND_HALF
(
tokens_grad
.
scalar_type
(),
"moe combine backward"
,
tokens_grad
.
scalar_type
(),
"moe combine backward"
,
moe_cb_bwd_launch
<
scalar_t
>
(
moe_cb_bwd_launch
<
scalar_t
>
(
tokens_grad
.
data
<
scalar_t
>
(),
egrad
.
data
<
scalar_t
>
(),
tokens_grad
.
data
<
scalar_t
>
(),
egrad
.
data
<
scalar_t
>
(),
expert_tokens
.
data
<
scalar_t
>
(),
expert_tokens
.
data
<
scalar_t
>
(),
logits
.
data
<
scalar_t
>
(),
logits
.
data
<
scalar_t
>
(),
wgrad
.
data
<
scalar_t
>
(),
wgrad
.
data
<
scalar_t
>
(),
mask
[
0
].
data
<
int
>
(),
mask
[
0
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
k
==
1
?
nullptr
:
mask
[
1
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
dest_idx
[
0
].
data
<
int
>
(),
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
k
==
1
?
dest_idx
[
0
].
data
<
int
>
()
:
dest_idx
[
1
].
data
<
int
>
(),
s
,
e
,
c
,
s
,
e
,
c
,
h
)
h
));
);
return
{
egrad
,
wgrad
};
return
{
egrad
,
wgrad
};
}
}
torch
::
Tensor
cumsum_sub_one_in_dim0
(
torch
::
Tensor
mask
)
{
torch
::
Tensor
cumsum_sub_one_in_dim0
(
torch
::
Tensor
mask
)
{
assert
(
mask
.
dim
()
==
2
);
assert
(
mask
.
dim
()
==
2
);
assert
(
mask
.
dtype
()
==
torch
::
kInt32
);
assert
(
mask
.
dtype
()
==
torch
::
kInt32
);
const
int
s
=
mask
.
size
(
0
),
e
=
mask
.
size
(
1
);
const
int
s
=
mask
.
size
(
0
),
e
=
mask
.
size
(
1
);
auto
res
=
torch
::
empty
({
s
,
e
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
mask
.
device
()));
auto
res
=
cumsum_launch
(
mask
.
data
<
int
>
(),
res
.
data
<
int
>
(),
s
,
e
);
torch
::
empty
({
s
,
e
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
mask
.
device
()));
cumsum_launch
(
mask
.
data
<
int
>
(),
res
.
data
<
int
>
(),
s
,
e
);
return
res
;
return
res
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment