Unverified Commit 1188f408 authored by Michael Yang's avatar Michael Yang Committed by GitHub
Browse files

s/From*Slice/From*s/ (#12255)

parent 15c7d30d
...@@ -43,7 +43,7 @@ func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int ...@@ -43,7 +43,7 @@ func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int
} }
} }
mask := ctx.Input().FromFloatSlice(flat, seqLength, seqLength) mask := ctx.Input().FromFloats(flat, seqLength, seqLength)
// Reshape to match [seqLength, seqLength, 1] for broadcasting // Reshape to match [seqLength, seqLength, 1] for broadcasting
mask = mask.Reshape(ctx, seqLength, seqLength, 1) mask = mask.Reshape(ctx, seqLength, seqLength, 1)
...@@ -299,7 +299,7 @@ func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) ...@@ -299,7 +299,7 @@ func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int)
} }
} }
t := ctx.Input().FromIntSlice(index, len(index)) t := ctx.Input().FromInts(index, len(index))
return t, bounds return t, bounds
} }
...@@ -319,7 +319,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor ...@@ -319,7 +319,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim))) freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
} }
} }
freqs := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize) freqs := ctx.Input().FromFloats(freqVals, freq, maxGridSize)
// Create position coordinates (y,x pairs) for the grid // Create position coordinates (y,x pairs) for the grid
// In PyTorch: Equivalent to generating position ids with torch.arange() // In PyTorch: Equivalent to generating position ids with torch.arange()
...@@ -329,7 +329,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor ...@@ -329,7 +329,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
coords = append(coords, int32(y), int32(x)) coords = append(coords, int32(y), int32(x))
} }
} }
pos := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height) pos := ctx.Input().FromInts(coords, 2, grid.Width, grid.Height)
// Reshape and permute positions to match spatial merging pattern // Reshape and permute positions to match spatial merging pattern
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge) pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
......
...@@ -181,7 +181,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { ...@@ -181,7 +181,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
// Forward implements model.Model. // Forward implements model.Model.
func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
......
...@@ -102,7 +102,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten ...@@ -102,7 +102,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
for i, t := range entry.mm { for i, t := range entry.mm {
if in == t.Tensor { if in == t.Tensor {
if !reserve { if !reserve {
return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil return ctx.Input().FromFloats(entry.data[i], t.Tensor.Shape()...), nil
} else { } else {
return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
} }
......
...@@ -599,7 +599,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er ...@@ -599,7 +599,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs)) batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs)) batch.Outputs = nextBatch.ctx.Input().FromInts(batchOutputs, len(batchOutputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch) nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
if err != nil { if err != nil {
err = fmt.Errorf("failed to build graph: %w", err) err = fmt.Errorf("failed to build graph: %w", err)
...@@ -692,7 +692,7 @@ func (s *Server) computeBatch(activeBatch batchState) { ...@@ -692,7 +692,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
// At this point the seqs are ready for forwardBatch to move forward so unblock // At this point the seqs are ready for forwardBatch to move forward so unblock
s.mu.Unlock() s.mu.Unlock()
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs) activeBatch.batch.Inputs.FromInts(batchInputs)
activeBatch.ctx.ComputeWithNotify( activeBatch.ctx.ComputeWithNotify(
func() { func() {
logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id) logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
...@@ -1090,7 +1090,7 @@ func (s *Server) reserveWorstCaseGraph() error { ...@@ -1090,7 +1090,7 @@ func (s *Server) reserveWorstCaseGraph() error {
batch.Positions[i] = int32(i) batch.Positions[i] = int32(i)
} }
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) batch.Inputs = ctx.Input().FromInts(batchInputs, len(batchInputs))
batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel) batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
cache := s.model.Config().Cache cache := s.model.Config().Cache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment