norm / vllm
"torchvision/vscode:/vscode.git/clone" did not exist on "6b1646cae7e4a0118bd769cd9921ff2269561047"
Commit 1b20639a (unverified), authored Jan 30, 2024 by Hanzhi Zhou, committed by GitHub Jan 29, 2024

No repeated IPC open (#2642)
Parent: b72af8f1

Showing 1 changed file with 25 additions and 18 deletions.

csrc/custom_all_reduce.cuh (+25, -18)
@@ -7,6 +7,7 @@
 #include <iostream>
 #include <limits>
+#include <map>
 #include <unordered_map>
 #include <vector>
@@ -327,6 +328,10 @@ __global__ void __launch_bounds__(512, 1)
   }
 }
 
+using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
+static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
+static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
+
 class CustomAllreduce {
  public:
  int rank_;
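The IPC_KEY alias introduced in the hunk above exists so that raw cudaIpcMemHandle_t values can serve as keys of the std::map added later in this commit: std::map needs operator< on its key type, which the opaque CUDA handle struct does not provide, whereas std::array compares lexicographically, and the two static_asserts check that the byte-array view is layout-compatible with the handle. Below is a minimal standalone sketch of the same idea, not part of the commit; OpaqueHandle and KEY are stand-ins invented for illustration so the example compiles without the CUDA toolkit.

// Minimal sketch (not from the commit): OpaqueHandle stands in for
// cudaIpcMemHandle_t so this builds without CUDA headers.
#include <array>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <map>

struct OpaqueHandle {   // opaque POD handle with no operator< of its own
  char reserved[64];
};

using KEY = std::array<uint8_t, sizeof(OpaqueHandle)>;
static_assert(sizeof(KEY) == sizeof(OpaqueHandle));
static_assert(alignof(KEY) == alignof(OpaqueHandle));

int main() {
  OpaqueHandle h{};
  std::memcpy(h.reserved, "example", 8);

  // Copy the handle's bytes into a KEY; std::array's lexicographic
  // operator< is what makes it usable as a std::map key.
  KEY key;
  std::memcpy(key.data(), &h, sizeof(h));

  std::map<KEY, char *> opened;
  opened[key] = nullptr;
  std::cout << "cached handles: " << opened.size() << "\n";  // prints 1
}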
@@ -341,7 +346,8 @@ class CustomAllreduce {
   // stores the registered device pointers from all ranks
   RankData *d_rank_data_base_, *d_rank_data_end_;
   std::vector<void *> graph_unreg_buffers_;
-  std::vector<void *> ipc_handles_;
+  // a map from IPC handles to opened IPC pointers
+  std::map<IPC_KEY, char *> ipc_handles_;
 
   /**
    * meta is a pointer to device metadata and temporary buffer for allreduce.
@@ -365,10 +371,7 @@ class CustomAllreduce {
     for (int i = 0; i < world_size_; i++) {
       Metadata *rank_meta;
       if (i != rank_) {
-        char *handle;
-        CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i],
-                                       cudaIpcMemLazyEnablePeerAccess));
-        ipc_handles_.push_back(handle);
+        char *handle = open_ipc_handle(&handles[i]);
         handle += offsets[i];
         rank_meta = (Metadata *)handle;
       } else {
@@ -378,6 +381,19 @@ class CustomAllreduce {
     }
   }
 
+  char *open_ipc_handle(const void *ipc_handle) {
+    auto [it, new_handle] =
+        ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
+    if (new_handle) {
+      char *ipc_ptr;
+      CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
+                                     *((const cudaIpcMemHandle_t *)ipc_handle),
+                                     cudaIpcMemLazyEnablePeerAccess));
+      it->second = ipc_ptr;
+    }
+    return it->second;
+  }
+
   std::pair<std::vector<uint8_t>, std::vector<int64_t>>
   get_graph_buffer_ipc_meta() {
     auto num_buffers = graph_unreg_buffers_.size();
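The open_ipc_handle() helper added in the hunk above is the heart of the change: std::map::insert returns an {iterator, inserted} pair, so cudaIpcOpenMemHandle runs only the first time a given handle is seen, and every later caller receives the cached pointer. Below is a CPU-only sketch of the same lookup-or-open pattern; ExpensiveOpen, open_once, and the counter are hypothetical stand-ins used only to show that the expensive call happens once per distinct key.

#include <array>
#include <cstdint>
#include <iostream>
#include <map>

using KEY = std::array<uint8_t, 8>;

static int open_calls = 0;

// Stand-in for cudaIpcOpenMemHandle: "opening" is expensive, so count it.
char *ExpensiveOpen(const KEY &) {
  ++open_calls;
  return new char[1];
}

std::map<KEY, char *> cache;

char *open_once(const KEY &key) {
  // insert() reports whether the key was already present.
  auto [it, inserted] = cache.insert({key, nullptr});
  if (inserted) it->second = ExpensiveOpen(key);  // only on first sight
  return it->second;                              // cached pointer otherwise
}

int main() {
  KEY a{1}, b{2};
  open_once(a);
  open_once(a);  // cache hit: ExpensiveOpen is not called again
  open_once(b);
  std::cout << "opens: " << open_calls << "\n";  // prints 2
}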
@@ -413,11 +429,7 @@ class CustomAllreduce {
     RankData data;
     for (int i = 0; i < world_size_; i++) {
       if (i != rank_) {
-        char *handle;
-        CUDACHECK(cudaIpcOpenMemHandle(
-            (void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data()),
-            cudaIpcMemLazyEnablePeerAccess));
-        ipc_handles_.push_back(handle);
+        char *handle = open_ipc_handle(handles[i].data());
         handle += offsets[i];
         data.ptrs[i] = handle;
       } else {
@@ -448,13 +460,8 @@ class CustomAllreduce {
       auto &rd = rank_data[i];
       for (int j = 0; j < world_size_; j++) {
         if (j != rank_) {
-          char *handle;
-          CUDACHECK(cudaIpcOpenMemHandle(
-              (void **)&handle,
-              *((cudaIpcMemHandle_t *)&handles[j]
-                                          [i * sizeof(cudaIpcMemHandle_t)]),
-              cudaIpcMemLazyEnablePeerAccess));
-          ipc_handles_.push_back(handle);
+          char *handle = open_ipc_handle(
+              &handles[j][i * sizeof(cudaIpcMemHandle_t)]);
           handle += offsets[j][i];
           rd.ptrs[j] = handle;
         } else {
@@ -541,7 +548,7 @@ class CustomAllreduce {
   }
 
   ~CustomAllreduce() {
-    for (auto ptr : ipc_handles_) {
+    for (auto [_, ptr] : ipc_handles_) {
       CUDACHECK(cudaIpcCloseMemHandle(ptr));
     }
   }
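With ipc_handles_ now a map, the destructor in the final hunk walks its (handle, pointer) pairs with structured bindings and closes each opened pointer exactly once, rather than a vector that could have held the same opened pointer more than once. A tiny sketch of that cleanup loop follows; FakeClose is a stand-in for cudaIpcCloseMemHandle, purely for illustration.

#include <array>
#include <cstdint>
#include <iostream>
#include <map>

using KEY = std::array<uint8_t, 8>;

// Stand-in for cudaIpcCloseMemHandle.
void FakeClose(char *ptr) { delete[] ptr; }

int main() {
  std::map<KEY, char *> ipc_handles;
  ipc_handles[KEY{1}] = new char[1];
  ipc_handles[KEY{2}] = new char[1];

  // Mirrors the new destructor: each distinct handle is closed once.
  for (auto [_, ptr] : ipc_handles) {
    FakeClose(ptr);
  }
  std::cout << "closed " << ipc_handles.size() << " handles\n";
}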