Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
691e92e1
Commit
691e92e1
authored
Jan 01, 2021
by
Rick Ho
Browse files
fix ptr correctness
parent
7e949a62
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
73 additions
and
75 deletions
+73
-75
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+73
-75
No files found.
pytorch/cuda/moe_cuda_kernel.cu
View file @
691e92e1
...
@@ -20,8 +20,9 @@
...
@@ -20,8 +20,9 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
#define MOE_DEBUG
// #define MOE_BREAKDOWN
// #define MOE_BREAKDOWN
//
#define MOE_DEBUG_SCATTER
#define MOE_DEBUG_SCATTER
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
__global__
...
@@ -56,6 +57,12 @@ void batch_gather_kernel(int wid, int* pos,
...
@@ -56,6 +57,12 @@ void batch_gather_kernel(int wid, int* pos,
}
}
}
}
/**
 * Debug helper: fetch the first element behind a device pointer.
 *
 * Copies sizeof(scalar_t) bytes from d_ptr (device memory) into a host
 * temporary with a cudaMemcpy and returns that value. Despite the name,
 * nothing is printed on success; the caller is expected to format the
 * returned value itself.
 *
 * @param d_ptr  device pointer to at least one scalar_t element.
 * @return       the element at d_ptr[0]; a value-initialised scalar_t
 *               (zero for arithmetic types) if the copy fails.
 */
template <typename scalar_t>
scalar_t print_first_float(scalar_t* d_ptr) {
    // Value-initialise so a failed copy yields a well-defined result
    // instead of an indeterminate one.
    scalar_t v = scalar_t();
    cudaError_t err = cudaMemcpy(&v, d_ptr, sizeof(scalar_t),
                                 cudaMemcpyDeviceToHost);
    if (err != cudaSuccess) {
        // Report rather than abort: this is a best-effort debug probe.
        fprintf(stderr, "print_first_float: cudaMemcpy failed: %s\n",
                cudaGetErrorString(err));
    }
    return v;
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
moe_cuda_forward_impl
(
void
moe_cuda_forward_impl
(
...
@@ -80,14 +87,10 @@ void moe_cuda_forward_impl(
...
@@ -80,14 +87,10 @@ void moe_cuda_forward_impl(
scalar_t
*
local_input_buf
,
*
local_output_buf
;
scalar_t
*
local_input_buf
,
*
local_output_buf
;
checkCudaErrors
(
cudaMalloc
(
&
local_input_buf
,
sizeof
(
scalar_t
)
*
batch_size
*
checkCudaErrors
(
cudaMalloc
(
&
local_input_buf
,
in_feat
));
sizeof
(
scalar_t
)
*
batch_size
*
in_feat
));
checkCudaErrors
(
cudaMalloc
(
&
local_output_buf
,
#ifdef MOE_BREAKDOWN
sizeof
(
scalar_t
)
*
batch_size
*
out_feat
));
timestamp
(
t_malloc
);
fprintf
(
stderr
,
"Malloc time %.3lf us
\n
"
,
getDuration
(
t_init
,
t_malloc
)
*
1e6
);
#endif
int
*
gate
=
new
int
[
batch_size
];
int
*
gate
=
new
int
[
batch_size
];
int
*
expert_count
=
new
int
[
tot_expert
],
*
expert_ptr
=
new
int
[
tot_expert
];
int
*
expert_count
=
new
int
[
tot_expert
],
*
expert_ptr
=
new
int
[
tot_expert
];
...
@@ -96,12 +99,6 @@ void moe_cuda_forward_impl(
...
@@ -96,12 +99,6 @@ void moe_cuda_forward_impl(
checkCudaErrors
(
cudaMemcpy
(
gate
,
d_gate
,
sizeof
(
int
)
*
batch_size
,
checkCudaErrors
(
cudaMemcpy
(
gate
,
d_gate
,
sizeof
(
int
)
*
batch_size
,
cudaMemcpyDeviceToHost
));
cudaMemcpyDeviceToHost
));
#ifdef MOE_BREAKDOWN
timestamp
(
t_cpy
);
fprintf
(
stderr
,
"Copy time %.3lf us
\n
"
,
getDuration
(
t_malloc
,
t_cpy
)
*
1e6
);
#endif
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
++
expert_count
[
gate
[
i
]];
++
expert_count
[
gate
[
i
]];
}
}
...
@@ -117,6 +114,10 @@ void moe_cuda_forward_impl(
...
@@ -117,6 +114,10 @@ void moe_cuda_forward_impl(
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
pos
[
i
]
=
expert_ptr
[
gate
[
i
]]
++
;
pos
[
i
]
=
expert_ptr
[
gate
[
i
]]
++
;
}
}
for
(
int
i
=
batch_size
-
1
;
i
>
0
;
--
i
)
{
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
];
}
expert_ptr
[
0
]
=
0
;
checkCudaErrors
(
cudaMemcpy
(
d_pos
,
pos
,
sizeof
(
int
)
*
batch_size
,
checkCudaErrors
(
cudaMemcpy
(
d_pos
,
pos
,
sizeof
(
int
)
*
batch_size
,
cudaMemcpyHostToDevice
));
cudaMemcpyHostToDevice
));
...
@@ -133,25 +134,16 @@ void moe_cuda_forward_impl(
...
@@ -133,25 +134,16 @@ void moe_cuda_forward_impl(
}
}
expert_sz
+=
expert_n
[
i
];
expert_sz
+=
expert_n
[
i
];
}
}
scalar_t
*
input_buf
,
*
hidden_buf
,
*
output_buf
;
checkCudaErrors
(
cudaMalloc
(
&
hidden_buf
,
sizeof
(
scalar_t
)
*
expert_sz
*
hidden_feat
));
#ifdef MOE_DEBUG
scalar_t
*
input_buf
,
*
hidden_buf
,
*
output_buf
;
for
(
int
i
=
0
;
i
<
tot_expert
;
++
i
)
{
if
(
expert_sz
)
{
fprintf
(
stderr
,
"%d %d %d
\n
"
,
cm
->
rank
,
i
,
expert_count
[
i
]);
checkCudaErrors
(
cudaMalloc
(
&
hidden_buf
,
}
sizeof
(
scalar_t
)
*
expert_sz
*
hidden_feat
));
if
(
cm
->
rank
==
0
)
{
for
(
int
i
=
0
;
i
<
tot_expert
;
++
i
)
{
fprintf
(
stderr
,
"%d "
,
all_expert_count
[
i
]);
}
fprintf
(
stderr
,
"
\n
"
);
}
}
#endif
#ifdef MOE_BREAKDOWN
#ifdef MOE_BREAKDOWN
timestamp
(
t_expert
);
timestamp
(
t_expert
);
fprintf
(
stderr
,
"Expert asn time %.3lf us
\n
"
,
getDuration
(
t_
cpy
,
t_expert
)
*
fprintf
(
stderr
,
"Expert asn time %.3lf us
\n
"
,
getDuration
(
t_
init
,
t_expert
)
*
1e6
);
1e6
);
#endif
#endif
...
@@ -159,40 +151,45 @@ void moe_cuda_forward_impl(
...
@@ -159,40 +151,45 @@ void moe_cuda_forward_impl(
<<<
batch_size
,
256
,
0
,
h
->
getStream
(
0
)
>>>
(
in_feat
,
d_pos
,
input
,
<<<
batch_size
,
256
,
0
,
h
->
getStream
(
0
)
>>>
(
in_feat
,
d_pos
,
input
,
local_input_buf
);
local_input_buf
);
h
->
sync
(
0
);
h
->
sync
(
0
);
// fprintf(stderr, "First %d lin %.3f\n", cm->rank, print_first_float(local_input_buf));
if
(
cm
->
rank
>
1
)
{
checkCudaErrors
(
cudaMalloc
(
&
input_buf
,
if
(
cm
->
size
>
1
)
{
sizeof
(
scalar_t
)
*
expert_sz
*
in_feat
));
if
(
expert_sz
)
{
checkCudaErrors
(
cudaMalloc
(
&
output_buf
,
checkCudaErrors
(
cudaMalloc
(
&
input_buf
,
sizeof
(
scalar_t
)
*
expert_sz
*
out_feat
));
sizeof
(
scalar_t
)
*
expert_sz
*
in_feat
));
ncclGroupStart
();
checkCudaErrors
(
cudaMalloc
(
&
output_buf
,
sizeof
(
scalar_t
)
*
expert_sz
*
out_feat
));
}
int
recv_ptr
=
0
;
int
recv_ptr
=
0
;
for
(
int
i
=
0
;
i
<
num_expert
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_expert
;
++
i
)
{
NCCL_SAFE_CALL
(
ncclGroupStart
());
for
(
int
j
=
0
;
j
<
cm
->
size
;
++
j
)
{
for
(
int
j
=
0
;
j
<
cm
->
size
;
++
j
)
{
int
send_id
=
i
+
j
*
num_expert
;
int
idx
=
i
+
j
*
num_expert
;
if
(
expert_count
[
send_id
])
{
if
(
expert_count
[
idx
])
{
ncclSend
(
local_input_buf
+
expert_ptr
[
send_id
]
*
in_feat
,
NCCL_SAFE_CALL
(
ncclSend
(
expert_count
[
send_id
]
*
in_feat
*
sizeof
(
scalar_t
),
local_input_buf
+
expert_ptr
[
idx
]
*
in_feat
,
expert_count
[
idx
]
*
in_feat
*
sizeof
(
scalar_t
),
ncclChar
,
ncclChar
,
j
,
j
,
cm
->
ncclcomm
,
cm
->
ncclcomm
,
h
->
getStream
(
0
));
h
->
getStream
(
0
))
)
;
}
}
i
nt
recv_id
=
i
*
cm
->
size
+
j
;
i
f
(
all_expert_count
[
idx
])
{
if
(
all_expert_count
[
recv_id
])
{
NCCL_SAFE_CALL
(
ncclRecv
(
ncclRecv
(
input_buf
+
recv_ptr
*
in_feat
,
input_buf
+
recv_ptr
*
in_feat
,
all_expert_count
[
recv_
id
]
*
in_feat
*
sizeof
(
scalar_t
),
all_expert_count
[
id
x
]
*
in_feat
*
sizeof
(
scalar_t
),
ncclChar
,
ncclChar
,
j
,
j
,
cm
->
ncclcomm
,
cm
->
ncclcomm
,
h
->
getStream
(
0
));
h
->
getStream
(
0
))
)
;
recv_ptr
+=
all_expert_count
[
recv_
id
];
recv_ptr
+=
all_expert_count
[
id
x
];
}
}
}
}
NCCL_SAFE_CALL
(
ncclGroupEnd
());
}
}
ncclGroupEnd
();
}
else
{
}
else
{
input_buf
=
local_input_buf
;
input_buf
=
local_input_buf
;
output_buf
=
local_output_buf
;
}
}
#ifdef MOE_BREAKDOWN
#ifdef MOE_BREAKDOWN
...
@@ -202,6 +199,9 @@ void moe_cuda_forward_impl(
...
@@ -202,6 +199,9 @@ void moe_cuda_forward_impl(
1e6
);
1e6
);
#endif
#endif
h
->
sync
(
0
);
// fprintf(stderr, "First %d in %.3f\n", cm->rank, print_first_float(input_buf));
scalar_t
alpha
=
1
,
beta
=
0
;
scalar_t
alpha
=
1
,
beta
=
0
;
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
...
@@ -209,9 +209,8 @@ void moe_cuda_forward_impl(
...
@@ -209,9 +209,8 @@ void moe_cuda_forward_impl(
continue
;
continue
;
}
}
#ifdef MOE_DEBUG_SCATTER
#ifdef MOE_DEBUG_SCATTER
fprintf
(
stderr
,
"gemm %d sz %d
\n
"
,
i
,
expert_n
[
i
]);
fprintf
(
stderr
,
"worker %d gemm %d sz %d offset %d
\n
"
,
cm
->
rank
,
i
,
expert_n
[
i
],
ptr
);
fprintf
(
stderr
,
"GeMM %d x %d x %d
\n
"
,
out_feat
,
expert_n
[
i
],
// fprintf(stderr, "worker %d GeMM %d x %d x %d\n", cm->rank, out_feat, expert_n[i], in_feat);
in_feat
);
#endif
#endif
// Use T(B) x T(A) = T(C) to produce row-major C
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors
(
cublasXgemm
(
h
->
getHandle
(
i
),
checkCudaErrors
(
cublasXgemm
(
h
->
getHandle
(
i
),
...
@@ -246,37 +245,34 @@ void moe_cuda_forward_impl(
...
@@ -246,37 +245,34 @@ void moe_cuda_forward_impl(
1e6
);
1e6
);
#endif
#endif
if
(
cm
->
rank
>
1
)
{
if
(
cm
->
size
>
1
)
{
checkCudaErrors
(
cudaMalloc
(
&
local_output_buf
,
sizeof
(
scalar_t
)
*
batch_size
*
out_feat
));
ncclGroupStart
();
int
send_ptr
=
0
;
int
send_ptr
=
0
;
for
(
int
i
=
0
;
i
<
num_expert
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_expert
;
++
i
)
{
NCCL_SAFE_CALL
(
ncclGroupStart
());
for
(
int
j
=
0
;
j
<
cm
->
size
;
++
j
)
{
for
(
int
j
=
0
;
j
<
cm
->
size
;
++
j
)
{
int
recv_id
=
i
+
j
*
num_expert
;
int
idx
=
i
+
j
*
num_expert
;
if
(
expert_count
[
recv_id
])
{
if
(
expert_count
[
idx
])
{
ncclRecv
(
local_output_buf
+
expert_ptr
[
recv_id
]
*
in_feat
,
NCCL_SAFE_CALL
(
ncclRecv
(
expert_count
[
recv_id
]
*
in_feat
*
sizeof
(
scalar_t
),
local_output_buf
+
expert_ptr
[
idx
]
*
out_feat
,
expert_count
[
idx
]
*
out_feat
*
sizeof
(
scalar_t
),
ncclChar
,
ncclChar
,
j
,
j
,
cm
->
ncclcomm
,
cm
->
ncclcomm
,
h
->
getStream
(
0
));
h
->
getStream
(
0
))
)
;
}
}
i
nt
send_id
=
i
*
cm
->
size
+
j
;
i
f
(
all_expert_count
[
idx
])
{
if
(
all_expert_count
[
send_id
])
{
NCCL_SAFE_CALL
(
ncclSend
(
ncclSend
(
output_buf
+
send_ptr
*
in
_feat
,
output_buf
+
send_ptr
*
out
_feat
,
all_expert_count
[
send_
id
]
*
in
_feat
*
sizeof
(
scalar_t
),
all_expert_count
[
id
x
]
*
out
_feat
*
sizeof
(
scalar_t
),
ncclChar
,
ncclChar
,
j
,
j
,
cm
->
ncclcomm
,
cm
->
ncclcomm
,
h
->
getStream
(
0
));
h
->
getStream
(
0
))
)
;
send_ptr
+=
all_expert_count
[
send_
id
];
send_ptr
+=
all_expert_count
[
id
x
];
}
}
}
}
NCCL_SAFE_CALL
(
ncclGroupEnd
());
}
}
ncclGroupEnd
();
}
else
{
local_output_buf
=
output_buf
;
}
}
batch_gather_kernel
<
scalar_t
>
batch_gather_kernel
<
scalar_t
>
...
@@ -292,13 +288,15 @@ void moe_cuda_forward_impl(
...
@@ -292,13 +288,15 @@ void moe_cuda_forward_impl(
1e6
);
1e6
);
#endif
#endif
cudaFree
(
input_buf
);
if
(
expert_sz
)
{
cudaFree
(
hidden_buf
);
cudaFree
(
hidden_buf
);
cudaFree
(
output_buf
);
if
(
cm
->
size
>
1
)
{
if
(
cm
->
rank
>
1
)
{
cudaFree
(
input_buf
);
cudaFree
(
local_in
put_buf
);
cudaFree
(
out
put_buf
);
cudaFree
(
local_output_buf
);
}
}
}
cudaFree
(
local_input_buf
);
cudaFree
(
local_output_buf
);
cudaFree
(
d_pos
);
cudaFree
(
d_pos
);
delete
[]
pos
;
delete
[]
pos
;
delete
[]
gate
;
delete
[]
gate
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment