// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::*;

use anyhow::Result;
7
use nixl_sys::{MemoryRegion, NixlDescriptor, XferDescList, XferStatus};
8
use std::future::Future;
Ryan Olson's avatar
Ryan Olson committed
9

10
fn append_xfer_request<Source, Destination>(
Ryan Olson's avatar
Ryan Olson committed
11
    src: &Source,
12
13
14
15
    dst: &mut Destination,
    src_dl: &mut XferDescList,
    dst_dl: &mut XferDescList,
) -> Result<()>
Ryan Olson's avatar
Ryan Olson committed
16
17
where
    Source: BlockDataProvider,
Ryan Olson's avatar
Ryan Olson committed
18
    Source::StorageType: NixlDescriptor,
Ryan Olson's avatar
Ryan Olson committed
19
    Destination: BlockDataProviderMut,
Ryan Olson's avatar
Ryan Olson committed
20
    Destination::StorageType: NixlDescriptor,
Ryan Olson's avatar
Ryan Olson committed
21
{
Ryan Olson's avatar
Ryan Olson committed
22
23
    let src_data = src.block_data();
    let dst_data = dst.block_data_mut();
Ryan Olson's avatar
Ryan Olson committed
24
25
26
27
28
29
30
31
32
33

    if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
        let src_desc = src_data.block_view()?.as_nixl_descriptor();
        let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut();

        unsafe {
            src_dl.add_desc(
                src_desc.as_ptr() as usize,
                src_desc.size(),
                src_desc.device_id(),
34
            );
Ryan Olson's avatar
Ryan Olson committed
35
36
37
38
39

            dst_dl.add_desc(
                dst_desc.as_ptr() as usize,
                dst_desc.size(),
                dst_desc.device_id(),
40
            );
Ryan Olson's avatar
Ryan Olson committed
41
42
        }

43
        Ok(())
Ryan Olson's avatar
Ryan Olson committed
44
45
    } else {
        assert_eq!(src_data.num_layers(), dst_data.num_layers());
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
        for layer_idx in 0..src_data.num_layers() {
            for outer_idx in 0..src_data.num_outer_dims() {
                let src_view = src_data.layer_view(layer_idx, outer_idx)?;
                let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;

                debug_assert_eq!(src_view.size(), dst_view.size());

                let src_desc = src_view.as_nixl_descriptor();
                let dst_desc = dst_view.as_nixl_descriptor_mut();

                unsafe {
                    src_dl.add_desc(
                        src_desc.as_ptr() as usize,
                        src_desc.size(),
                        src_desc.device_id(),
61
                    );
62
63
64
65
66

                    dst_dl.add_desc(
                        dst_desc.as_ptr() as usize,
                        dst_desc.size(),
                        dst_desc.device_id(),
67
                    );
68
69
70
71
                }
            }
        }
        Ok(())
Ryan Olson's avatar
Ryan Olson committed
72
73
74
    }
}

75
76
/// Copy a block from a source to a destination using CUDA memcpy
pub fn write_blocks_to<Source, Destination>(
Ryan Olson's avatar
Ryan Olson committed
77
    src: &[Source],
78
    dst: &mut [Destination],
79
    ctx: &Arc<TransferContext>,
80
    transfer_type: NixlTransfer,
81
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
Ryan Olson's avatar
Ryan Olson committed
82
83
where
    Source: BlockDataProvider,
Ryan Olson's avatar
Ryan Olson committed
84
    Source::StorageType: NixlDescriptor,
Ryan Olson's avatar
Ryan Olson committed
85
    Destination: BlockDataProviderMut,
Ryan Olson's avatar
Ryan Olson committed
86
    Destination::StorageType: NixlDescriptor,
Ryan Olson's avatar
Ryan Olson committed
87
{
88
89
90
91
    if src.is_empty() || dst.is_empty() {
        return Ok(Box::new(std::future::ready(())));
    }
    assert_eq!(src.len(), dst.len());
Ryan Olson's avatar
Ryan Olson committed
92

93
94
95
96
97
    let nixl_agent_arc = ctx.as_ref().nixl_agent();
    let nixl_agent = nixl_agent_arc
        .as_ref()
        .as_ref()
        .expect("NIXL agent not found");
Ryan Olson's avatar
Ryan Olson committed
98

99
100
101
    let src_mem_type = src
        .first()
        .unwrap()
Ryan Olson's avatar
Ryan Olson committed
102
        .block_data()
103
104
105
106
107
        .storage_type()
        .nixl_mem_type();
    let dst_mem_type = dst
        .first()
        .unwrap()
Ryan Olson's avatar
Ryan Olson committed
108
        .block_data()
109
110
111
        .storage_type()
        .nixl_mem_type();

112
113
    let mut src_dl = XferDescList::new(src_mem_type)?;
    let mut dst_dl = XferDescList::new(dst_mem_type)?;
114
115
116

    for (src, dst) in src.iter().zip(dst.iter_mut()) {
        append_xfer_request(src, dst, &mut src_dl, &mut dst_dl)?;
Ryan Olson's avatar
Ryan Olson committed
117
118
    }

119
120
121
122
123
124
125
    let xfer_req = nixl_agent.create_xfer_req(
        transfer_type.as_xfer_op(),
        &src_dl,
        &dst_dl,
        &nixl_agent.name(),
        None,
    )?;
126

127
128
129
130
131
132
133
134
135
136
137
    let still_pending = nixl_agent.post_xfer_req(&xfer_req, None)?;

    if still_pending {
        Ok(Box::new(Box::pin(async move {
            let nixl_agent = nixl_agent_arc
                .as_ref()
                .as_ref()
                .expect("NIXL agent not found");

            loop {
                match nixl_agent.get_xfer_status(&xfer_req) {
138
139
140
141
                    Ok(XferStatus::Success) => break, // Transfer is complete.
                    Ok(XferStatus::InProgress) => {
                        tokio::time::sleep(std::time::Duration::from_millis(5)).await
                    } // Transfer is still in progress.
142
143
144
145
146
147
148
149
150
                    Err(e) => {
                        tracing::error!("Error getting transfer status: {}", e);
                        break;
                    }
                }
            }
        })))
    } else {
        Ok(Box::new(std::future::ready(())))
Ryan Olson's avatar
Ryan Olson committed
151
152
    }
}