// nixl.rs
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::{poll_fn, Future};
use std::ops::Range;
use std::task::Poll;

use anyhow::{Context as _, Result};
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};

use super::*;

/// Copy a block from a source to a destination using CUDA memcpy
pub fn write_block_to<'a, Source, Destination>(
    src: &'a Source,
    dst: &'a mut Destination,
28
    ctx: Arc<TransferContext>,
Ryan Olson's avatar
Ryan Olson committed
29
    notify: Option<String>,
30
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
Ryan Olson's avatar
Ryan Olson committed
31
32
33
34
35
36
37
38
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
    let src_data = src.block_data(private::PrivateToken);
    let dst_data = dst.block_data_mut(private::PrivateToken);

    if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
39
40
41
42
43
44
45
        // Keep the arc to use in the returned future.
        let nixl_agent_arc = ctx.as_ref().nixl_agent();

        let nixl_agent = nixl_agent_arc
            .as_ref()
            .as_ref()
            .expect("NIXL agent not found");
Ryan Olson's avatar
Ryan Olson committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

        let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
        let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;

        let src_desc = src_data.block_view()?.as_nixl_descriptor();
        let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut();

        unsafe {
            src_dl.add_desc(
                src_desc.as_ptr() as usize,
                src_desc.size(),
                src_desc.device_id(),
            )?;

            dst_dl.add_desc(
                dst_desc.as_ptr() as usize,
                dst_desc.size(),
                dst_desc.device_id(),
            )?;
        }

67
68
69
        let xfer_req = nixl_agent
            .create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)
            .unwrap();
Ryan Olson's avatar
Ryan Olson committed
70
71
72
73
74
75
76
77

        let mut xfer_args = OptArgs::new()?;

        if let Some(notify) = notify {
            xfer_args.set_has_notification(true)?;
            xfer_args.set_notification_message(notify.as_bytes())?;
        }

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
        let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;

        // Return a future that completes when the transfer is complete.
        // TODO: How efficient is this? Can we do better?
        Ok(Box::new(poll_fn(move |_cx| {
            let nixl_agent = nixl_agent_arc
                .as_ref()
                .as_ref()
                .expect("NIXL agent not found");

            // The nixl agent returns true if the transfer is still in progress.
            if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
                Poll::Ready(())
            } else {
                Poll::Pending
Ryan Olson's avatar
Ryan Olson committed
93
            }
94
        })))
Ryan Olson's avatar
Ryan Olson committed
95
96
    } else {
        assert_eq!(src_data.num_layers(), dst_data.num_layers());
97
        write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)
Ryan Olson's avatar
Ryan Olson committed
98
99
100
101
102
103
104
105
    }
}

/// Copy a range of layers from a source to a destination using CUDA memcpy
pub fn write_layers_to<'a, Source, Destination>(
    layer_range: Range<usize>,
    src: &'a Source,
    dst: &'a mut Destination,
106
    ctx: Arc<TransferContext>,
Ryan Olson's avatar
Ryan Olson committed
107
    notify: Option<String>,
108
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
Ryan Olson's avatar
Ryan Olson committed
109
110
111
112
113
114
115
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
    let src_data = src.block_data(private::PrivateToken);
    let dst_data = dst.block_data_mut(private::PrivateToken);

116
117
118
119
120
    let nixl_agent_arc = ctx.as_ref().nixl_agent();
    let nixl_agent = nixl_agent_arc
        .as_ref()
        .as_ref()
        .expect("NIXL agent not found");
Ryan Olson's avatar
Ryan Olson committed
121

122
    let remote_worker_id = dst_data.worker_id.to_string();
Ryan Olson's avatar
Ryan Olson committed
123
124
125
126
127
128
129
130
131
132
133
134
    let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
    let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;

    // #[cfg(debug_assertions)]
    // {
    //     let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy<
    //         Destination::StorageType,
    //     >>::write_to_strategy();
    //     assert_eq!(strategy, expected_strategy);
    // }

    for layer_idx in layer_range {
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
        for outer_idx in 0..src_data.num_outer_dims() {
            let src_view = src_data.layer_view(layer_idx, outer_idx)?;
            let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;

            debug_assert_eq!(src_view.size(), dst_view.size());

            let src_desc = src_view.as_nixl_descriptor();
            let dst_desc = dst_view.as_nixl_descriptor_mut();

            unsafe {
                src_dl.add_desc(
                    src_desc.as_ptr() as usize,
                    src_desc.size(),
                    src_desc.device_id(),
                )?;

                dst_dl.add_desc(
                    dst_desc.as_ptr() as usize,
                    dst_desc.size(),
                    dst_desc.device_id(),
                )?;
            }
Ryan Olson's avatar
Ryan Olson committed
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
        }
    }

    let mut xfer_args = OptArgs::new()?;

    if let Some(notify) = notify {
        xfer_args.set_has_notification(true)?;
        xfer_args.set_notification_message(notify.as_bytes())?;
    }

    let xfer_req = nixl_agent.create_xfer_req(
        XferOp::Write,
        &src_dl,
        &dst_dl,
        &remote_worker_id,
        Some(&xfer_args),
    )?;

175
176
177
178
179
180
181
182
183
184
185
    let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;

    Ok(Box::new(poll_fn(move |_cx| {
        let nixl_agent = nixl_agent_arc
            .as_ref()
            .as_ref()
            .expect("NIXL agent not found");
        if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
            Poll::Ready(())
        } else {
            Poll::Pending
Ryan Olson's avatar
Ryan Olson committed
186
        }
187
    })))
Ryan Olson's avatar
Ryan Olson committed
188
}