"src/turbomind/models/llama/LlamaV2.cc" did not exist on "720fc533da804ac3f46ee938864403e51fcd9fa7"
edgewise_spline_weighting.py 319 Bytes
Newer Older
rusty1s's avatar
rename  
rusty1s committed
1
2
3
4
5
6
7
8
9
10
import torch

# The GPU implementation is imported only when CUDA is available; on a
# machine without CUDA the name `EdgewiseSplineWeightingGPU` is never bound.
if torch.cuda.is_available():
    from .edgewise_spline_weighting_gpu import EdgewiseSplineWeightingGPU


def edgewise_spline_weighting(input, weight, amount, index):
    """Apply edgewise spline weighting, dispatching on the device of *input*.

    Only a GPU implementation exists: when ``input.is_cuda`` is true the call
    is forwarded to ``EdgewiseSplineWeightingGPU``; otherwise the function
    raises.

    Args:
        input: Object exposing ``is_cuda`` (presumably a ``torch.Tensor`` --
            confirm against callers). Passed as the first call argument to
            the GPU implementation.
        weight: Passed as the second call argument to the GPU implementation.
        amount: Forwarded to the ``EdgewiseSplineWeightingGPU`` constructor.
        index: Forwarded to the ``EdgewiseSplineWeightingGPU`` constructor.

    Returns:
        Whatever ``EdgewiseSplineWeightingGPU(amount, index)(input, weight)``
        returns.

    Raises:
        NotImplementedError: If ``input.is_cuda`` is false, since no CPU
            implementation is available.
    """
    if not input.is_cuda:
        raise NotImplementedError(
            'edgewise_spline_weighting is only implemented for CUDA inputs')

    # `EdgewiseSplineWeightingGPU` is bound at module import time only when
    # CUDA is available, which is always the case once `input.is_cuda` holds.
    return EdgewiseSplineWeightingGPU(amount, index)(input, weight)