Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5dd5b531
Commit
5dd5b531
authored
Dec 18, 2024
by
ThomasNing
Browse files
Start adding mscclpp support to the reduce_receive_kernel
parent
9d6e47f0
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
170 additions
and
163 deletions
+170
-163
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.cpp
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.cpp
+6
-149
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.hpp
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.hpp
+0
-2
include/ck_tile/ops/cross_gpu_reduce.hpp
include/ck_tile/ops/cross_gpu_reduce.hpp
+1
-0
include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp
...ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp
+134
-0
include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_receive_kernel.hpp
.../ops/cross_gpu_reduce/kernel/cross_gpu_receive_kernel.hpp
+15
-7
include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_reduce_tile_partitioner.hpp
...s_gpu_reduce/kernel/cross_gpu_reduce_tile_partitioner.hpp
+9
-2
include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_send_kernel.hpp
...ile/ops/cross_gpu_reduce/kernel/cross_gpu_send_kernel.hpp
+5
-3
No files found.
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.cpp
View file @
5dd5b531
...
...
@@ -6,142 +6,13 @@
#include <iostream>
#include <string>
#include <thread>
#include <future>
#include <vector>
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wsuggest-destructor-override"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wshadow-field-in-constructor"
#pragma clang diagnostic ignored "-Wdocumentation"
#pragma clang diagnostic ignored "-Winconsistent-missing-destructor-override"
#pragma clang diagnostic ignored "-Wcast-align"
#pragma clang diagnostic ignored "-Wglobal-constructors"
#pragma clang diagnostic ignored "-Wdeprecated-copy-with-user-provided-dtor"
#include <mscclpp/core.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/semaphore.hpp>
#pragma clang diagnostic pop
#include "cross_gpu_reduce.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/cross_gpu_reduce.hpp"
template
<
class
T
>
using
DeviceHandle
=
mscclpp
::
DeviceHandle
<
T
>
;
extern
__constant__
DeviceHandle
<
mscclpp
::
SmChannel
>
constSlaveSmChannels
[
8
];
// For SmChannel
extern
__constant__
DeviceHandle
<
mscclpp
::
SmChannel
>
constMasterSmChannel
;
void
setupConnection
(
int
rank
,
int
slaveRank
,
int
worldSize
,
void
*
dst_data
,
size_t
dataSize
)
{
// Initialize MSCCL++ Communicator
auto
bootstrap
=
std
::
make_shared
<
mscclpp
::
TcpBootstrap
>
(
rank
,
worldSize
);
mscclpp
::
Communicator
comm
(
bootstrap
);
mscclpp
::
Transport
transport
=
mscclpp
::
Transport
::
CudaIpc
;
// We'll register our local memory. For the slave, this might be the destination buffer.
// For senders, this might be the source buffer or a local buffer we expose to the slave.
mscclpp
::
RegisteredMemory
localMemory
=
comm
.
registerMemory
(
dst_data
,
dataSize
,
transport
);
if
(
rank
==
slaveRank
)
{
std
::
vector
<
mscclpp
::
NonblockingFuture
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>>
connectionFutures
;
std
::
vector
<
mscclpp
::
NonblockingFuture
<
mscclpp
::
RegisteredMemory
>>
remoteMemFutures
;
std
::
vector
<
std
::
shared_ptr
<
mscclpp
::
SmDevice2DeviceSemaphore
>>
slave_semaphore_list
(
worldSize
);
for
(
size_t
senderRank
=
0
;
senderRank
<
static_cast
<
size_t
>
(
worldSize
);
++
senderRank
)
{
if
(
senderRank
==
static_cast
<
size_t
>
(
rank
))
continue
;
connectionFutures
.
push_back
(
comm
.
connectOnSetup
(
senderRank
,
0
,
transport
));
comm
.
sendMemoryOnSetup
(
localMemory
,
senderRank
,
0
);
remoteMemFutures
.
push_back
(
comm
.
recvMemoryOnSetup
(
senderRank
,
0
));
}
comm
.
setup
();
// Now retrieve all completed futures
std
::
vector
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>
connections
;
connections
.
reserve
(
connectionFutures
.
size
());
for
(
auto
&
cf
:
connectionFutures
)
{
connections
.
push_back
(
cf
.
get
());
}
std
::
vector
<
mscclpp
::
RegisteredMemory
>
remoteMemories
;
remoteMemories
.
reserve
(
remoteMemFutures
.
size
());
for
(
auto
&
rmf
:
remoteMemFutures
)
{
remoteMemories
.
push_back
(
rmf
.
get
());
}
// Create semaphores and channels
// One semaphore per connection
std
::
vector
<
std
::
shared_ptr
<
mscclpp
::
SmDevice2DeviceSemaphore
>>
slaveSemaphores
;
slaveSemaphores
.
reserve
(
connections
.
size
());
for
(
auto
&
conn
:
connections
)
{
slaveSemaphores
.
push_back
(
std
::
make_shared
<
mscclpp
::
SmDevice2DeviceSemaphore
>
(
comm
,
conn
));
}
// Create channels
std
::
vector
<
DeviceHandle
<
mscclpp
::
SmChannel
>>
SmChannels
;
SmChannels
.
reserve
(
slaveSemaphores
.
size
());
for
(
size_t
i
=
0
;
i
<
slaveSemaphores
.
size
();
++
i
)
{
SmChannels
.
push_back
(
mscclpp
::
deviceHandle
(
mscclpp
::
SmChannel
(
slaveSemaphores
[
i
],
remoteMemories
[
i
],
// Remote buffer from the sender
dst_data
// Local buffer (this slave's buffer)
)));
}
hipError_t
error_slave
=
hipMemcpyToSymbol
(
constSlaveSmChannels
,
SmChannels
.
data
(),
sizeof
(
DeviceHandle
<
mscclpp
::
SmChannel
>
)
*
SmChannels
.
size
());
if
(
error_slave
!=
hipSuccess
)
{
std
::
cerr
<<
"Error locating data to constant memory"
<<
std
::
endl
;
return
;
}
}
else
{
// This is a sender:
// We only connect to the slave, send our memory handle, and receive the slave's memory
// handle.
mscclpp
::
NonblockingFuture
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>
connectionFuture
=
comm
.
connectOnSetup
(
slaveRank
,
0
,
transport
);
// Send our memory to the slave
comm
.
sendMemoryOnSetup
(
localMemory
,
slaveRank
,
0
);
// Receive slave's memory
mscclpp
::
NonblockingFuture
<
mscclpp
::
RegisteredMemory
>
remoteMemoryFuture
=
comm
.
recvMemoryOnSetup
(
slaveRank
,
0
);
comm
.
setup
();
std
::
shared_ptr
<
mscclpp
::
Connection
>
connection
=
connectionFuture
.
get
();
mscclpp
::
RegisteredMemory
remoteMemory
=
remoteMemoryFuture
.
get
();
auto
senderSemaphore
=
std
::
make_shared
<
mscclpp
::
SmDevice2DeviceSemaphore
>
(
comm
,
connection
);
auto
senderChannel
=
mscclpp
::
SmChannel
(
senderSemaphore
,
localMemory
,
remoteMemory
.
data
());
DeviceHandle
<
mscclpp
::
SmChannel
>
senderSmChannel
=
mscclpp
::
deviceHandle
(
senderChannel
);
hipError_t
error_master
=
hipMemcpyToSymbol
(
constMasterSmChannel
,
&
senderSmChannel
,
sizeof
(
DeviceHandle
<
mscclpp
::
SmChannel
>
));
if
(
error_master
!=
hipSuccess
)
{
std
::
cerr
<<
"Error locating data to constant memory"
<<
std
::
endl
;
return
;
}
}
}
template
<
typename
InputType
,
typename
OutputType
>
struct
AllocateAndTransferFunctor
...
...
@@ -151,9 +22,8 @@ struct AllocateAndTransferFunctor
ck_tile
::
index_t
host_gpu
,
int
device_id
,
const
ck_tile
::
ArgParser
&
arg_parser
,
const
ck_tile
::
stream_config
&
s
,
std
::
promise
<
const
void
*>&
host_receive_ptr_promise
,
std
::
future
<
const
void
*>&
host_receive_ptr_future
)
const
ck_tile
::
stream_config
&
s
)
{
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"M"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"N"
);
...
...
@@ -218,7 +88,6 @@ struct AllocateAndTransferFunctor
// initialize the receive data buffer and global memory location.
ck_tile
::
HostTensor
<
InputType
>
receive_host
({
M
,
N
});
ck_tile
::
DeviceMem
receive_buf
(
receive_host
.
get_element_space_size_in_bytes
());
args_receive
.
p_receive
=
receive_buf
.
GetDeviceBuffer
();
// initialize the output data buffer.
std
::
string
output_type
=
arg_parser
.
get_str
(
"output_type"
);
if
(
output_type
.
compare
(
"float"
)
==
0
)
...
...
@@ -226,9 +95,7 @@ struct AllocateAndTransferFunctor
ck_tile
::
HostTensor
<
OutputType
>
output_host
({
M
,
N
});
ck_tile
::
DeviceMem
output_buf
(
output_host
.
get_element_space_size_in_bytes
());
args_receive
.
p_output
=
output_buf
.
GetDeviceBuffer
();
host_receive_ptr_promise
.
set_value
(
args_receive
.
p_receive
);
auto
kargs_slave
=
SlaveKernel
::
MakeKargs
(
args_receive
.
p_reduce
,
args_receive
.
p_receive
,
args_receive
.
p_output
,
args_receive
.
M
,
args_receive
.
N
);
...
...
@@ -246,10 +113,8 @@ struct AllocateAndTransferFunctor
}
else
{
const
void
*
send_location_ptr
=
host_receive_ptr_future
.
get
();
args_send
.
p_send
=
send_location_ptr
;
auto
kargs_master
=
MasterKernel
::
MakeKargs
(
args_send
.
p_reduce
,
args_send
.
p_send
,
args_send
.
M
,
args_send
.
N
);
args_send
.
p_reduce
,
args_send
.
M
,
args_send
.
N
);
const
dim3
grids_master
=
MasterKernel
::
GridSize
(
M
,
N
);
ave_time
=
ck_tile
::
launch_kernel
(
s
,
...
...
@@ -268,9 +133,7 @@ struct AllocateAndTransferFunctor
ck_tile
::
HostTensor
<
InputType
>&
host_tensor
,
ck_tile
::
DeviceMem
&
device_mem
,
ck_tile
::
index_t
host_gpu
,
const
ck_tile
::
ArgParser
&
arg_parser
,
std
::
promise
<
const
void
*>&
host_receive_ptr_promise
,
std
::
future
<
const
void
*>&
host_receive_ptr_future
)
const
ck_tile
::
ArgParser
&
arg_parser
)
{
hipError_t
hip_err_set_device
=
hipSetDevice
(
device_id
);
if
(
hip_err_set_device
!=
hipSuccess
)
...
...
@@ -298,9 +161,7 @@ struct AllocateAndTransferFunctor
host_gpu
,
device_id
,
arg_parser
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
},
host_receive_ptr_promise
,
host_receive_ptr_future
);
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
}
};
...
...
@@ -438,8 +299,6 @@ bool run_cross_gpu_reduce(ck_tile::ArgParser arg_parser)
}
}
std
::
promise
<
const
void
*>
host_receive_ptr_promise
;
std
::
future
<
const
void
*>
host_receive_ptr_future
=
host_receive_ptr_promise
.
get_future
();
for
(
int
i
=
0
;
i
<
gpu_nums
;
++
i
)
{
...
...
@@ -448,9 +307,7 @@ bool run_cross_gpu_reduce(ck_tile::ArgParser arg_parser)
std
::
ref
(
transfer_tensor_host_list
[
i
]),
std
::
ref
(
transfer_bufs
[
i
]),
host_gpu
,
arg_parser
,
std
::
ref
(
host_receive_ptr_promise
),
std
::
ref
(
host_receive_ptr_future
));
arg_parser
);
}
// Wait for all threads to complete
...
...
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.hpp
View file @
5dd5b531
...
...
@@ -8,7 +8,6 @@
struct
transfer_receive_basic_args
{
const
void
*
p_reduce
;
const
void
*
p_receive
;
const
void
*
p_output
;
ck_tile
::
index_t
host_gpu
;
ck_tile
::
index_t
device_id
;
...
...
@@ -19,7 +18,6 @@ struct transfer_receive_basic_args
struct
transfer_send_basic_args
{
const
void
*
p_reduce
;
const
void
*
p_send
;
ck_tile
::
index_t
host_gpu
;
ck_tile
::
index_t
device_id
;
ck_tile
::
index_t
M
;
...
...
include/ck_tile/ops/cross_gpu_reduce.hpp
View file @
5dd5b531
...
...
@@ -7,6 +7,7 @@
#include "ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_send_kernel.hpp"
#include "ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_reduce_shape.hpp"
#include "ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_reduce_tile_partitioner.hpp"
#include "ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp"
#include "ck_tile/ops/cross_gpu_reduce/pipeline/reduce_receive_pipeline_scale_up.hpp"
#include "ck_tile/ops/cross_gpu_reduce/pipeline/reduce_receive_pipeline_default_policy.hpp"
#include "ck_tile/ops/cross_gpu_reduce/pipeline/reduce_send_pipeline_scale_up.hpp"
...
...
include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp
0 → 100644
View file @
5dd5b531
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

// The mscclpp headers trip several warnings under this project's strict warning
// set, so suppress them for these includes only.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wsuggest-destructor-override"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wshadow-field-in-constructor"
#pragma clang diagnostic ignored "-Wdocumentation"
#pragma clang diagnostic ignored "-Winconsistent-missing-destructor-override"
#pragma clang diagnostic ignored "-Wcast-align"
#pragma clang diagnostic ignored "-Wglobal-constructors"
#pragma clang diagnostic ignored "-Wdeprecated-copy-with-user-provided-dtor"
#include <mscclpp/core.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/semaphore.hpp>
// FIX: the original header pushed the diagnostic state but never popped it,
// leaking the warning suppressions into every translation unit that includes
// this header (the pre-move .cpp version did pop).
#pragma clang diagnostic pop

template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;

// Device-side channel handles filled in by setupConnection() via
// hipMemcpyToSymbol; the matching __constant__ definitions live in the kernel
// translation units.
extern __constant__ DeviceHandle<mscclpp::SmChannel> constSlaveSmChannels[8]; // For SmChannel
extern __constant__ DeviceHandle<mscclpp::SmChannel> constMasterSmChannel;

/// Establish MSCCL++ SmChannel connections between one receiving ("slave")
/// rank and every other ("sender") rank, then publish the resulting device
/// handles to the __constant__ symbols above.
///
/// @param rank      this process' rank in the bootstrap world
/// @param slaveRank the rank acting as receiver/aggregator
/// @param worldSize total number of ranks
/// @param dst_data  local buffer registered with the communicator (the receive
///                  buffer on the slave; the exposed send buffer on senders)
/// @param dataSize  size of dst_data in bytes
///
/// On hipMemcpyToSymbol failure the error is reported to std::cerr and the
/// function returns early.
///
/// FIX: marked `inline` — this is a function *definition* in a header, so
/// without `inline` any two translation units including this file would
/// violate the ODR and fail to link.
inline void setupConnection(int rank, int slaveRank, int worldSize, void* dst_data, size_t dataSize)
{
    // Initialize MSCCL++ Communicator
    auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, worldSize);
    mscclpp::Communicator comm(bootstrap);
    mscclpp::Transport transport = mscclpp::Transport::CudaIpc;

    // We'll register our local memory. For the slave, this might be the destination buffer.
    // For senders, this might be the source buffer or a local buffer we expose to the slave.
    mscclpp::RegisteredMemory localMemory = comm.registerMemory(dst_data, dataSize, transport);

    if(rank == slaveRank)
    {
        std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>>
            connectionFutures;
        std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemFutures;
        // FIX: removed the original `slave_semaphore_list(worldSize)` — it was
        // allocated and never used.
        for(size_t senderRank = 0; senderRank < static_cast<size_t>(worldSize); ++senderRank)
        {
            if(senderRank == static_cast<size_t>(rank))
                continue;
            connectionFutures.push_back(comm.connectOnSetup(senderRank, 0, transport));
            comm.sendMemoryOnSetup(localMemory, senderRank, 0);
            remoteMemFutures.push_back(comm.recvMemoryOnSetup(senderRank, 0));
        }
        comm.setup();

        // Now retrieve all completed futures
        std::vector<std::shared_ptr<mscclpp::Connection>> connections;
        connections.reserve(connectionFutures.size());
        for(auto& cf : connectionFutures)
        {
            connections.push_back(cf.get());
        }
        std::vector<mscclpp::RegisteredMemory> remoteMemories;
        remoteMemories.reserve(remoteMemFutures.size());
        for(auto& rmf : remoteMemFutures)
        {
            remoteMemories.push_back(rmf.get());
        }

        // Create semaphores and channels
        // One semaphore per connection
        std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> slaveSemaphores;
        slaveSemaphores.reserve(connections.size());
        for(auto& conn : connections)
        {
            slaveSemaphores.push_back(
                std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(comm, conn));
        }

        // Create channels
        std::vector<DeviceHandle<mscclpp::SmChannel>> SmChannels;
        SmChannels.reserve(slaveSemaphores.size());
        for(size_t i = 0; i < slaveSemaphores.size(); ++i)
        {
            SmChannels.push_back(mscclpp::deviceHandle(
                mscclpp::SmChannel(slaveSemaphores[i],
                                   remoteMemories[i], // Remote buffer from the sender
                                   dst_data           // Local buffer (this slave's buffer)
                                   )));
        }

        // NOTE(review): constSlaveSmChannels holds 8 entries; nothing here
        // guards SmChannels.size() <= 8 — confirm worldSize is bounded by the
        // caller.
        hipError_t error_slave =
            hipMemcpyToSymbol(constSlaveSmChannels,
                              SmChannels.data(),
                              sizeof(DeviceHandle<mscclpp::SmChannel>) * SmChannels.size());
        if(error_slave != hipSuccess)
        {
            std::cerr << "Error locating data to constant memory" << std::endl;
            return;
        }
    }
    else
    {
        // This is a sender:
        // We only connect to the slave, send our memory handle, and receive the slave's memory
        // handle.
        mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>> connectionFuture =
            comm.connectOnSetup(slaveRank, 0, transport);
        // Send our memory to the slave
        comm.sendMemoryOnSetup(localMemory, slaveRank, 0);
        // Receive slave's memory
        mscclpp::NonblockingFuture<mscclpp::RegisteredMemory> remoteMemoryFuture =
            comm.recvMemoryOnSetup(slaveRank, 0);
        comm.setup();

        std::shared_ptr<mscclpp::Connection> connection = connectionFuture.get();
        mscclpp::RegisteredMemory remoteMemory          = remoteMemoryFuture.get();

        auto senderSemaphore =
            std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(comm, connection);
        auto senderChannel =
            mscclpp::SmChannel(senderSemaphore, localMemory, remoteMemory.data());
        DeviceHandle<mscclpp::SmChannel> senderSmChannel = mscclpp::deviceHandle(senderChannel);

        hipError_t error_master = hipMemcpyToSymbol(
            constMasterSmChannel, &senderSmChannel, sizeof(DeviceHandle<mscclpp::SmChannel>));
        if(error_master != hipSuccess)
        {
            std::cerr << "Error locating data to constant memory" << std::endl;
            return;
        }
    }
}
include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_receive_kernel.hpp
View file @
5dd5b531
...
...
@@ -4,6 +4,9 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp"
// NOTE(review): this is a *definition* (not `extern`) of a __constant__ array
// placed in a header; cross_gpu_connect.hpp declares the same symbol `extern`.
// Including this header from more than one translation unit would produce
// duplicate-definition link errors — confirm it is included from exactly one TU.
__constant__
DeviceHandle
<
mscclpp
::
SmChannel
>
constSlaveSmChannels
[
8
];
// For SmChannel
namespace
ck_tile
{
template
<
typename
CrossReducePartitioner
,
typename
ReduceReceivePipeline_
>
...
...
@@ -17,19 +20,17 @@ struct ReduceReceiveKernel
struct
ReduceReceiveKargs
{
const
void
*
reduce_ptr
;
const
void
*
receive_ptr
;
const
void
*
output_ptr
;
index_t
M
;
index_t
N
;
};
CK_TILE_HOST
static
constexpr
ReduceReceiveKargs
MakeKargs
(
const
void
*
reduce_ptr
,
const
void
*
receive_ptr
,
const
void
*
output_ptr
,
index_t
M
,
index_t
N
)
{
return
ReduceReceiveKargs
{
reduce_ptr
,
receive_ptr
,
output_ptr
,
M
,
N
};
return
ReduceReceiveKargs
{
reduce_ptr
,
output_ptr
,
M
,
N
};
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
@@ -44,8 +45,9 @@ struct ReduceReceiveKernel
CK_TILE_DEVICE
void
operator
()(
ReduceReceiveKargs
kargs
)
const
{
const
auto
i_M
=
CrossReducePartitioner
{}();
const
DataType
*
reduce_start
=
static_cast
<
const
DataType
*>
(
kargs
.
reduce_ptr
);
auto
channel
=
*
constSlaveSmChannels
[
0
];
const
auto
[
i_m
,
i_n
]
=
CrossReducePartitioner
{}();
const
DataType
*
reduce_start
=
static_cast
<
const
DataType
*>
(
reduce_ptr
);
auto
transfer_tensor_view
=
[
&
]()
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
reduce_start
,
...
...
@@ -58,7 +60,13 @@ struct ReduceReceiveKernel
make_tile_window
(
transfer_tensor_view
,
make_tuple
(
number
<
ReduceReceivePipeline
::
Block_M
>
{},
number
<
ReduceReceivePipeline
::
Block_N
>
{}),
{
i_M
,
0
});
{
i_m
,
i_n
});
uint32_t
numThreads
=
static_cast
<
uint32_t
>
(
CrossReducePartitioner
::
NumThreads
(
kargs
.
M
,
kargs
.
N
));
uint32_t
threadId
=
static_cast
<
uint32_t
>
(
i_m
+
i_n
*
(
kargs
.
M
+
ReduceReceivePipeline
::
Block_M
-
1
)
/
ReduceReceivePipeline
::
Block_M
);
uint64_t
totalBytes
=
static_cast
<
uint64_t
>
(
ReduceReceivePipeline
::
Block_M
*
ReduceReceivePipeline
::
Block_N
*
sizeof
(
DataType
));
channel
.
get
(
0
,
totalBytes
,
threadId
,
numThreads
);
const
ODataType
*
output_start
=
static_cast
<
const
ODataType
*>
(
kargs
.
output_ptr
);
auto
output_tensor_view
=
[
&
]()
{
...
...
@@ -73,7 +81,7 @@ struct ReduceReceiveKernel
make_tile_window
(
output_tensor_view
,
make_tuple
(
number
<
ReduceReceivePipeline
::
Block_M
>
{},
number
<
ReduceReceivePipeline
::
Block_N
>
{}),
{
i_
M
,
0
});
{
i_
m
,
i_n
});
__shared__
char
smem_ptr
[
ReduceReceivePipeline
::
GetSmemSize
()];
...
...
include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_reduce_tile_partitioner.hpp
View file @
5dd5b531
...
...
@@ -14,6 +14,12 @@ struct CrossReducePartitioner
static
constexpr
index_t
kM
=
CrossReduceShape
::
Block_M
;
static
constexpr
index_t
kN
=
CrossReduceShape
::
Block_N
;
/// Number of tiles needed to cover an M x N problem: the ceiling-divided
/// counts along each dimension (by the tile sizes kM and kN), multiplied.
CK_TILE_HOST static constexpr auto NumThreads(index_t M, index_t N)
{
    const index_t tiles_m = (M + kM - 1) / kM;
    const index_t tiles_n = (N + kN - 1) / kN;
    return tiles_m * tiles_n;
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
)
{
index_t
GridDimX
=
(
M
+
kM
-
1
)
/
kM
;
...
...
@@ -22,8 +28,9 @@ struct CrossReducePartitioner
}
/// Tile origin for the current block as an (iM, iN) pair of element offsets.
/// Values are broadcast through readfirstlane so they are wave-uniform.
///
/// FIX: this span carried both the old single-index body (`return i_M;`) and
/// the new pair-returning body fused together by the diff extraction; only the
/// committed tuple-returning version is kept.
CK_TILE_DEVICE auto operator()()
{
    const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
    const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
    return make_tuple(iM, iN);
}
};
}
// namespace ck_tile
include/ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_send_kernel.hpp
View file @
5dd5b531
...
...
@@ -4,6 +4,9 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_connect.hpp"
// NOTE(review): like constSlaveSmChannels in the receive-kernel header, this is
// a non-extern __constant__ *definition* in a header, while
// cross_gpu_connect.hpp declares the symbol `extern` — including this header
// from more than one TU risks duplicate-definition link errors; confirm single
// inclusion.
__constant__
mscclpp
::
DeviceHandle
<
mscclpp
::
SmChannel
>
constMasterSmChannel
;
namespace
ck_tile
{
template
<
typename
CrossReducePartitioner
,
typename
ReduceSendPipeline_
>
...
...
@@ -15,15 +18,14 @@ struct ReduceSendKernel
struct
ReduceSendKargs
{
const
void
*
reduce_ptr
;
const
void
*
send_ptr
;
index_t
M
;
index_t
N
;
};
CK_TILE_HOST
static
constexpr
ReduceSendKargs
MakeKargs
(
const
void
*
reduce_ptr
,
const
void
*
send_ptr
,
index_t
M
,
index_t
N
)
MakeKargs
(
const
void
*
reduce_ptr
,
index_t
M
,
index_t
N
)
{
return
ReduceSendKargs
{
reduce_ptr
,
send_ptr
,
M
,
N
};
return
ReduceSendKargs
{
reduce_ptr
,
M
,
N
};
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment