"tools/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "9731e023255559b1c7537aad60b0d534738add85"
Unverified Commit 95191ebd authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

Migrate weak_ref_tensor to sgl-kernel (#12505)

parent 9a512cf9
...@@ -320,6 +320,7 @@ set(SOURCES ...@@ -320,6 +320,7 @@ set(SOURCES
"csrc/kvcacheio/transfer.cu" "csrc/kvcacheio/transfer.cu"
"csrc/mamba/causal_conv1d.cu" "csrc/mamba/causal_conv1d.cu"
"csrc/memory/store.cu" "csrc/memory/store.cu"
"csrc/memory/weak_ref_tensor.cpp"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu" "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
......
...@@ -396,6 +396,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -396,6 +396,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()"); m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
m.impl("store_kv_cache", &store_kv_cache); m.impl("store_kv_cache", &store_kv_cache);
m.def("weak_ref_tensor(Tensor tensor) -> Tensor");
m.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
/* /*
* From FlashInfer * From FlashInfer
*/ */
......
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/ops.h
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <vector>
at::Tensor weak_ref_tensor(const at::Tensor& tensor) {
TORCH_CHECK(tensor.is_cuda(), "weak_ref_tensor expects a CUDA tensor");
void* data_ptr = tensor.data_ptr();
std::vector<int64_t> sizes = tensor.sizes().vec();
std::vector<int64_t> strides = tensor.strides().vec();
auto options = tensor.options();
auto new_tensor = at::from_blob(data_ptr, sizes, strides, options);
return new_tensor;
}
...@@ -665,6 +665,7 @@ void transfer_kv_all_layer_direct_lf_pf( ...@@ -665,6 +665,7 @@ void transfer_kv_all_layer_direct_lf_pf(
/* /*
* From csrc/memory * From csrc/memory
*/ */
at::Tensor weak_ref_tensor(const at::Tensor& tensor);
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v); void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
/* /*
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment