Unverified Commit fa7776fd authored by Michael Yang, committed by GitHub

gpt-oss (#11672)



* bf16

* tests

* gpt-oss

* enable gptoss for engine

* rough estimate

* convert to mxfp4

* handle safetensors U8

* clamp glu/linear

* update tokenizer

* MXFP4 support

This implements the Open Compute Project (OCP) Microscaling (MX) FP4 format
as a tensor type, with backend implementations focusing
on mulmat and mulmatid on CPU, CUDA, and Metal.
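
For reference, here is a minimal Go sketch of the MXFP4 block layout described
by the OCP Microscaling spec. It is not the code added in this change, and the
nibble packing order and names are assumptions made for illustration:

package mxfp4sketch

import "math"

// mxfp4Block sketches one MXFP4 block: 32 FP4 (E2M1) elements packed two per
// byte plus one shared E8M0 exponent-only scale, 17 bytes per 32 weights.
type mxfp4Block struct {
    Scale uint8     // E8M0 code: the block scale is 2^(Scale-127)
    Qs    [16]uint8 // 32 packed 4-bit codes
}

// e2m1 maps a 4-bit code to its FP4 value; bit 3 is the sign.
var e2m1 = [16]float32{0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6}

// dequantize expands one block to 32 float32 values, assuming element i lives
// in the low nibble of Qs[i] and element i+16 in the high nibble.
func (b mxfp4Block) dequantize() (out [32]float32) {
    scale := float32(math.Ldexp(1, int(b.Scale)-127))
    for i, q := range b.Qs {
        out[i] = e2m1[q&0x0F] * scale
        out[i+16] = e2m1[q>>4] * scale
    }
    return out
}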

* Unit tests for MXFP4 support

This exercises various operations and shapes on both CPU and GPU (if a GPU is
detected on the system).

* cuda graph

* unit test adjustments

* cuda: optimize memory access

Read 4 bytes at a time (8 elements) when performing mul_mat_vec_mxfp4
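
Since each quantized element is a nibble, one aligned 4-byte load covers eight
weights. A standalone Go helper sketching that unpacking (the real change is a
CUDA kernel; this only illustrates the layout it exploits):

// unpack8 pulls eight 4-bit MXFP4 codes out of one 32-bit word, mirroring how
// the kernel can service eight elements per 4-byte read.
func unpack8(word uint32) (codes [8]uint8) {
    for i := range codes {
        codes[i] = uint8(word>>(4*i)) & 0x0F
    }
    return codes
}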

* mac: fix crash on old macos versions

cblas_sgemm is only supported on macOS v13.3 and up; however, bf16 is
only supported on v14+, so we were falling back to ggml-blas and
crashing on bf16 tensors. Checking whether the function is null
seems to be the simplest way to conditionally avoid registering the
backend.

* server: Minimum context length for gptoss

This model requires a minimum context length of 8192 to function
effectively. Users can set higher values through all the normal mechanisms,
but lower values will be silently reset.
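
Roughly, the reset behaves like the sketch below (illustrative names, not the
actual server code):

// effectiveContext applies the silent floor: requests above the model's
// minimum pass through, anything lower is reset to the minimum.
func effectiveContext(requested, modelMinimum int) int {
    if requested < modelMinimum {
        return modelMinimum // e.g. 8192 for gpt-oss
    }
    return requested
}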

* ggml: Multiply by numParallel for gptoss sliding window

When computing the graph size estimate, the context size is already
multiplied by numParallel so estimates reflect that. However, since
sliding window models use a smaller, fixed context size, they need
to manually take numParallel into account.
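
The per-layer KV length used by the estimate then looks roughly like this
(illustrative names; contextLength is assumed to already include the
numParallel factor):

// kvEntriesPerLayer estimates the KV-cache entries one layer needs when
// serving numParallel sequences. Full-attention layers use contextLength,
// which was already scaled by numParallel; sliding-window layers keep a fixed
// window per sequence, so they scale by numParallel here.
func kvEntriesPerLayer(contextLength, slidingWindow, batchSize, numParallel int) int {
    if slidingWindow <= 0 {
        return contextLength
    }
    return min(slidingWindow*numParallel+batchSize, contextLength)
}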

* gpt-oss integration

Includes the harmony parser, thinking levels, etc.

* fix sync

* fix tests

* fix lint

---------
Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
Co-authored-by: Jesse Gross <jesse@ollama.com>
Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
parent 0d38b665
@@ -15,3 +15,26 @@ func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
return t
}
type LinearBatch struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
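// Forward applies a per-row expert projection: indices selects one expert per
// row and MulmatID multiplies by that expert's weight matrix, with the
// matching bias (when present) gathered via Rows.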
func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
t = m.Weight.MulmatID(ctx, t, indices)
if m.Bias != nil {
var bias ml.Tensor
if len(indices.Shape()) > 1 {
// FIXME: Rows does not support 2D indices for a 2D input tensor so reshape indices to 1D.
bias = m.Bias.Rows(ctx, indices.Contiguous(ctx, indices.Dim(0)*indices.Dim(1))).
Duplicate(ctx).
Reshape(ctx, m.Bias.Dim(0), indices.Dim(0), indices.Dim(1))
} else {
bias = m.Bias.Rows(ctx, indices)
}
t = t.Add(ctx, bias)
}
return t
}
@@ -4,9 +4,15 @@ import "github.com/ollama/ollama/ml"
// Options contains optional parameters for RoPE function
type Options struct {
- OriginalContextLength int
Type int
Factors ml.Tensor
OriginalContextLength int
// YaRN options
ExtrapolationFactor,
AttentionFactor,
BetaFast,
BetaSlow float32
}
// WithOriginalContextLength sets a custom context length
@@ -31,3 +37,15 @@ func WithFactors(factors ml.Tensor) func(*Options) {
}
}
}
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
return func(opts *Options) {
opts.ExtrapolationFactor = extrapolationFactor
}
}
func WithAttentionFactor(attentionFactor float32) func(*Options) {
return func(opts *Options) {
opts.AttentionFactor = attentionFactor
}
}
@@ -22,7 +22,7 @@ var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
return BytePairEncoding{
- pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
pre: regexp2.MustCompile(pre, regexp2.None),
vocab: vocab,
}
}
...
package gptoss
import (
"cmp"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Transformer struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
// Forward implements model.Model.
func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
one := ctx.Input().FromFloatSlice([]float32{1}, 1)
for i, block := range m.TransformerBlocks {
m.Cache.SetLayer(i)
if c, ok := m.Cache.(*kvcache.WrapperCache); ok {
// Even layers are sliding window attention.
c.SetLayerType(i % 2)
}
var outputs ml.Tensor
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *Transformer) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
}
type Options struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength,
numExperts,
numExpertsUsed,
originalContextLength int
eps,
ropeBase,
ropeScale float32
}
func (o Options) RoPEOptions() []func(*rope.Options) {
return []func(*rope.Options){
rope.WithTypeNeoX(),
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
// NOTE: ggml sets this implicitly so there's no need to set it here
// rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0),
}
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
type TransformerBlock struct {
Attention *AttentionBlock
MLP *MLPBlock
}
func (d *TransformerBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs, one ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
}
hiddenStates = d.MLP.Forward(ctx, hiddenStates, one, opts)
return hiddenStates
}
type AttentionBlock struct {
Norm *nn.RMSNorm `gguf:"attn_norm"`
QKV *nn.Linear `gguf:"attn_qkv"`
Output *nn.Linear `gguf:"attn_out"`
Sinks ml.Tensor `gguf:"attn_sinks"`
}
func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
residual := hiddenStates
hiddenStates = attn.Norm.Forward(ctx, hiddenStates, opts.eps)
qkv := attn.QKV.Forward(ctx, hiddenStates)
// query = qkv[..., : num_attention_heads * head_dim].reshape(batch_size, num_attention_heads, head_dim)
query := qkv.View(ctx,
0,
opts.headDim(), qkv.Stride(0)*opts.headDim(),
opts.numHeads, qkv.Stride(1),
batchSize,
)
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
// key = qkv[..., num_attention_heads * head_dim:(num_attention_heads + num_key_value_heads) * head_dim].reshape(batch_size, num_key_value_heads, head_dim)
key := qkv.View(ctx,
qkv.Stride(0)*opts.headDim()*opts.numHeads,
opts.headDim(), qkv.Stride(0)*opts.headDim(),
opts.numKVHeads, qkv.Stride(1),
batchSize,
)
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
// value = qkv[..., (num_attention_heads + num_key_value_heads) * head_dim:].reshape(batch_size, num_key_value_heads, head_dim)
value := qkv.View(ctx,
qkv.Stride(0)*opts.headDim()*(opts.numHeads+opts.numKVHeads),
opts.headDim(), qkv.Stride(0)*opts.headDim(),
opts.numKVHeads, qkv.Stride(1),
batchSize,
)
cache.Put(ctx, key, value)
key, value, mask := cache.Get(ctx)
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1./math.Sqrt(float64(opts.headDim())))
scores = scores.Add(ctx, mask)
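// Attention sinks: concatenate one learned sink logit per head as an extra
// column so it absorbs probability mass in the softmax; the negative Pad below
// then drops that column before the weighted sum over values.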
scores = scores.Concat(ctx, attn.Sinks.Reshape(ctx, 1, 1, opts.numHeads, 1).Repeat(ctx, 1, batchSize), 0)
scores = scores.Softmax(ctx)
scores = scores.Pad(ctx, -1, 0, 0, 0)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention).Add(ctx, residual)
}
type MLPBlock struct {
Norm *nn.RMSNorm `gguf:"ffn_norm"`
Router *nn.Linear `gguf:"ffn_gate_inp"`
GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}
func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *Options) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
residual := hiddenStates
hiddenStates = mlp.Norm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
routingWeights := mlp.Router.Forward(ctx, hiddenStates)
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, sequenceLength*batchSize).Rows(ctx, selectedExperts)
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, sequenceLength*batchSize).Softmax(ctx)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, sequenceLength*batchSize)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
hiddenStates = mlp.GateUp.Forward(ctx, hiddenStates, selectedExperts)
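// The fused gate_up projection interleaves gate and linear outputs. Reshape
// into pairs and view the even elements as the gate (glu) path and the odd
// elements as the linear path; both are clamped before the gated activation
// quickgelu(glu) * (linear + 1).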
hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2))
dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)}
glu := hiddenStates.View(ctx, 0, dimStride...)
glu = glu.Contiguous(ctx)
glu = glu.Clamp(ctx, float32(math.Inf(-1)), 7.0)
glu = glu.QuickGELU(ctx)
linear := hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...)
linear = linear.Clamp(ctx, -7.0, 7.0)
hiddenStates = glu.Mul(ctx, linear.Add(ctx, one))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3))
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
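// Sum the weighted outputs of the selected experts to produce each token's
// final MoE output.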
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates.Add(ctx, residual)
}
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer",
strings.Join([]string{
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`\p{N}{1,3}`,
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
`\s*[\r\n]+`,
`\s+(?!\S)`,
`\s+`,
}, "|"),
),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1.),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
},
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWAMemCache(int32(c.Uint("attention.sliding_window")), 4096, m.Shift),
kvcache.NewCausalCache(m.Shift),
)
m.Cache.SetConfig(ml.CacheConfig{CachePadding: 32, PermutedV: true})
return &m, nil
}
func init() {
model.Register("gptoss", New)
}
@@ -4,6 +4,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
...
@@ -36,6 +36,7 @@ type ErrorResponse struct {
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Reasoning string `json:"reasoning,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
@@ -81,6 +82,10 @@ type StreamOptions struct {
IncludeUsage bool `json:"include_usage"`
}
type Reasoning struct {
Effort *string `json:"effort,omitempty"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
@@ -95,6 +100,7 @@ type ChatCompletionRequest struct {
TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"`
Tools []api.Tool `json:"tools"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
}
type ChatCompletion struct {
@@ -253,7 +259,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
SystemFingerprint: "fp_ollama",
Choices: []Choice{{
Index: 0,
- Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
FinishReason: func(reason string) *string {
if len(toolCalls) > 0 {
reason = "tool_calls"
@@ -278,10 +284,10 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
SystemFingerprint: "fp_ollama",
Choices: []ChunkChoice{{
Index: 0,
- Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
FinishReason: func(reason string) *string {
if len(reason) > 0 {
- if toolCallSent {
if toolCallSent || len(toolCalls) > 0 {
return &finishReasonToolCalls
}
return &reason
@@ -397,7 +403,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
for _, msg := range r.Messages {
switch content := msg.Content.(type) {
case string:
- messages = append(messages, api.Message{Role: msg.Role, Content: content})
messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning})
case []any:
for _, c := range content {
data, ok := c.(map[string]any)
@@ -508,6 +514,10 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
options["top_p"] = 1.0
}
if r.Reasoning != nil {
options["reasoning"] = *r.Reasoning.Effort
}
var format json.RawMessage
if r.ResponseFormat != nil {
switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
@@ -521,6 +531,13 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
}
var think *api.ThinkValue
if r.Reasoning != nil {
think = &api.ThinkValue{
Value: *r.Reasoning.Effort,
}
}
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
@@ -528,6 +545,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
Options: options,
Stream: &r.Stream,
Tools: r.Tools,
Think: think,
}, nil
}
...
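
For illustration, a request that exercises the new reasoning field could be
built like the sketch below, using the struct fields added above (the model
name and prompt are example values):

effort := "high"
req := ChatCompletionRequest{
    Model:     "gpt-oss:20b",
    Messages:  []Message{{Role: "user", Content: "Why is the sky blue?"}},
    Reasoning: &Reasoning{Effort: &effort},
}
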
@@ -111,7 +111,8 @@ func (m *Model) Capabilities() []model.Capability {
// Check for thinking capability
openingTag, closingTag := thinking.InferTags(m.Template.Template)
- if openingTag != "" && closingTag != "" {
hasTags := openingTag != "" && closingTag != ""
if hasTags || m.Config.ModelFamily == "gptoss" {
capabilities = append(capabilities, model.CapabilityThinking)
}
...