nixl.rs 5.01 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::*;

use anyhow::Result;
19
20
use nixl_sys::{MemoryRegion, NixlDescriptor, XferDescList};
use std::future::Future;
Ryan Olson's avatar
Ryan Olson committed
21

22
23
24
25
26
27
fn append_xfer_request<Source, Destination>(
    src: &Arc<Source>,
    dst: &mut Destination,
    src_dl: &mut XferDescList,
    dst_dl: &mut XferDescList,
) -> Result<()>
Ryan Olson's avatar
Ryan Olson committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
    let src_data = src.block_data(private::PrivateToken);
    let dst_data = dst.block_data_mut(private::PrivateToken);

    if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
        let src_desc = src_data.block_view()?.as_nixl_descriptor();
        let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut();

        unsafe {
            src_dl.add_desc(
                src_desc.as_ptr() as usize,
                src_desc.size(),
                src_desc.device_id(),
            )?;

            dst_dl.add_desc(
                dst_desc.as_ptr() as usize,
                dst_desc.size(),
                dst_desc.device_id(),
            )?;
        }

53
        Ok(())
Ryan Olson's avatar
Ryan Olson committed
54
55
    } else {
        assert_eq!(src_data.num_layers(), dst_data.num_layers());
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
        for layer_idx in 0..src_data.num_layers() {
            for outer_idx in 0..src_data.num_outer_dims() {
                let src_view = src_data.layer_view(layer_idx, outer_idx)?;
                let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;

                debug_assert_eq!(src_view.size(), dst_view.size());

                let src_desc = src_view.as_nixl_descriptor();
                let dst_desc = dst_view.as_nixl_descriptor_mut();

                unsafe {
                    src_dl.add_desc(
                        src_desc.as_ptr() as usize,
                        src_desc.size(),
                        src_desc.device_id(),
                    )?;

                    dst_dl.add_desc(
                        dst_desc.as_ptr() as usize,
                        dst_desc.size(),
                        dst_desc.device_id(),
                    )?;
                }
            }
        }
        Ok(())
Ryan Olson's avatar
Ryan Olson committed
82
83
84
    }
}

85
86
87
88
/// Copy a block from a source to a destination using CUDA memcpy
pub fn write_blocks_to<Source, Destination>(
    src: &[Arc<Source>],
    dst: &mut [Destination],
89
    ctx: &Arc<TransferContext>,
90
    transfer_type: NixlTransfer,
91
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
Ryan Olson's avatar
Ryan Olson committed
92
93
94
95
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
96
97
98
99
    if src.is_empty() || dst.is_empty() {
        return Ok(Box::new(std::future::ready(())));
    }
    assert_eq!(src.len(), dst.len());
Ryan Olson's avatar
Ryan Olson committed
100

101
102
103
104
105
    let nixl_agent_arc = ctx.as_ref().nixl_agent();
    let nixl_agent = nixl_agent_arc
        .as_ref()
        .as_ref()
        .expect("NIXL agent not found");
Ryan Olson's avatar
Ryan Olson committed
106

107
108
109
110
111
112
113
114
115
116
117
118
119
    let src_mem_type = src
        .first()
        .unwrap()
        .block_data(private::PrivateToken)
        .storage_type()
        .nixl_mem_type();
    let dst_mem_type = dst
        .first()
        .unwrap()
        .block_data(private::PrivateToken)
        .storage_type()
        .nixl_mem_type();

jthomson04's avatar
jthomson04 committed
120
121
    let mut src_dl = XferDescList::new(src_mem_type, true)?;
    let mut dst_dl = XferDescList::new(dst_mem_type, true)?;
122
123
124

    for (src, dst) in src.iter().zip(dst.iter_mut()) {
        append_xfer_request(src, dst, &mut src_dl, &mut dst_dl)?;
Ryan Olson's avatar
Ryan Olson committed
125
126
    }

127
128
    debug_assert!(!src_dl.has_overlaps()? && !dst_dl.has_overlaps()?);

129
130
131
132
133
134
135
    let xfer_req = nixl_agent.create_xfer_req(
        transfer_type.as_xfer_op(),
        &src_dl,
        &dst_dl,
        &nixl_agent.name(),
        None,
    )?;
136

137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
    let still_pending = nixl_agent.post_xfer_req(&xfer_req, None)?;

    if still_pending {
        Ok(Box::new(Box::pin(async move {
            let nixl_agent = nixl_agent_arc
                .as_ref()
                .as_ref()
                .expect("NIXL agent not found");

            loop {
                match nixl_agent.get_xfer_status(&xfer_req) {
                    Ok(false) => break, // Transfer is complete.
                    Ok(true) => tokio::time::sleep(std::time::Duration::from_millis(5)).await, // Transfer is still in progress.
                    Err(e) => {
                        tracing::error!("Error getting transfer status: {}", e);
                        break;
                    }
                }
            }
        })))
    } else {
        Ok(Box::new(std::future::ready(())))
Ryan Olson's avatar
Ryan Olson committed
159
160
    }
}