Unverified Commit d3b4b997 authored by Daniel Hiltgen, committed by GitHub

app: add code for macOS and Windows apps under 'app' (#12933)



* app: add code for macOS and Windows apps under 'app'

* app: add readme

* app: windows and linux only for now

* ci: fix ui CI validation

---------
Co-authored-by: jmorganca <jmorganca@gmail.com>
parent a4770107
import { parents, type Proxy } from "unist-util-parents";
import type { Plugin } from "unified";
import type {
Emphasis,
Node,
Parent,
Root,
RootContent,
Text,
Strong,
PhrasingContent,
Paragraph,
} from "mdast";
import { u } from "unist-builder";
declare module "unist" {
interface Node {
/** Added by `unist-util-parents` (or your own walk). */
parent?: Proxy & Parent;
}
}
// interface SimpleTextRule {
// pattern: RegExp;
// transform: (matches: RegExpExecArray[], lastNode: Proxy) => void;
// }
// const simpleTextRules: SimpleTextRule[] = [
// // TODO(drifkin): generalize this for `__`/`_`/`~~`/`~` etc.
// {
// pattern: /(\*\*)(?=\S|$)/g,
// transform: (matchesIterator, lastNode) => {
// const textNode = lastNode.node as Text;
// const matches = [...matchesIterator];
// const lastMatch = matches[matches.length - 1];
// const origValue = textNode.value;
// const start = lastMatch.index;
// const sep = lastMatch[1];
// const before = origValue.slice(0, start);
// const after = origValue.slice(start + sep.length);
// if (lastNode.parent) {
// const index = (lastNode.parent.node as Parent).children.indexOf(
// lastNode.node as RootContent,
// );
// const shouldRemove = before.length === 0;
// if (!shouldRemove) {
// textNode.value = before;
// }
// const newNode = u("strong", {
// children: [u("text", { value: after })],
// });
// (lastNode.parent.node as Parent).children.splice(
// index + (shouldRemove ? 0 : 1),
// shouldRemove ? 1 : 0,
// newNode,
// );
// }
// },
// },
// ];
interface Options {
debug?: boolean;
onLastNode?: (info: LastNodeInfo) => void;
}
export interface LastNodeInfo {
path: string[];
type: string;
value?: string;
lastChars?: string;
fullNode: Node;
}
/**
* Removes `child` from `parent` in-place.
* @returns `true` if the child was found and removed; `false` otherwise.
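* e.g. removing an `emphasis` child from a `paragraph` node splices it out of
* `paragraph.children` and returns `true`.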
*/
export function removeChildFromParent(
child: RootContent,
parent: Node,
): boolean {
if (!isParent(parent)) return false; // parent isn’t a Parent → nothing to do
const idx = parent.children.indexOf(child);
if (idx < 0) return false; // not a child → nothing to remove
parent.children.splice(idx, 1);
return true; // removal successful
}
/** Narrow a generic `Node` to a `Parent` (i.e. one that really has children). */
function isParent(node: Node): node is Parent {
// A `Parent` always has a `children` array; make sure it's an array first.
return Array.isArray((node as Partial<Parent>).children);
}
/**
* Follow “last-child” pointers until you reach a leaf.
* Returns the right-most, deepest node in source order.
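* e.g. for the mdast tree of `*abc*` (root → paragraph → emphasis → text "abc"),
* this returns the `text` node.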
*/
export function findRightmostDeepestNode(root: Node): Node {
let current: Node = root;
// While the current node *is* a Parent and has at least one child…
while (isParent(current) && current.children.length > 0) {
const lastIndex = current.children.length - 1;
current = current.children[lastIndex];
}
return current; // Leaf: no further children
}
const remarkStreamingMarkdown: Plugin<[Options?], Root> = () => {
return (tree) => {
const treeWithParents = parents(tree);
const lastNode = findRightmostDeepestNode(treeWithParents) as Proxy;
const parentNode = lastNode.parent;
const grandparentNode = parentNode?.parent;
let ruleMatched = false;
// handling `* *` -> ``
//
// if the last node is part of a <list item (otherwise empty)> ->
// <list (otherwise empty)> -> <list item (last node, empty)>, then we need to
// remove everything up to and including the first list item. This happens
// when we have `* *`, which can become a bolded list item OR a horizontal
// line
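// e.g. the partial stream `* *` parses as list > listItem > list > listItem
// (all otherwise empty), so we drop the outer list item until more tokens
// arrive and disambiguate it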
if (
lastNode.type === "listItem" &&
parentNode &&
grandparentNode &&
parentNode.type === "list" &&
grandparentNode.type === "listItem" &&
parentNode.children.length === 1 &&
grandparentNode.children.length === 1
) {
ruleMatched = true;
if (grandparentNode.parent) {
removeChildFromParent(
grandparentNode.node as RootContent,
grandparentNode.parent.node,
);
}
// Handle `*` -> ``:
//
// if the last node is just an empty list item, we need to remove it
// because it could become something else (e.g., a horizontal line)
} else if (
lastNode.type === "listItem" &&
parentNode &&
parentNode.type === "list"
) {
ruleMatched = true;
removeChildFromParent(lastNode.node as RootContent, parentNode.node);
} else if (lastNode.type === "thematicBreak") {
ruleMatched = true;
const parent = lastNode.parent;
if (parent) {
removeChildFromParent(lastNode.node as RootContent, parent.node);
}
} else if (lastNode.type === "text") {
const textNode = lastNode.node as Text;
if (textNode.value.endsWith("**")) {
ruleMatched = true;
textNode.value = textNode.value.slice(0, -2);
// if there's a newline then a number, this is very very likely a
// numbered list item. Let's just hide it until the period comes (or
// other text disambiguates it)
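// (e.g. if the streamed text currently ends with "\n2", the "2" and its
// newline are trimmed until later tokens disambiguate it, such as "2."
// starting an ordered list)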
} else {
const match = textNode.value.match(/^([0-9]+)$/m);
if (match) {
const number = match[1];
textNode.value = textNode.value.slice(0, -number.length - 1);
ruleMatched = true;
// if the text node is now empty, then we might want to remove other
// elements, like a now-empty containing paragraph, or a break that
// might disappear once more tokens come in
if (textNode.value.length === 0) {
if (
lastNode.parent?.type === "paragraph" &&
lastNode.parent.children.length === 1
) {
// remove the whole paragraph if it's now empty (otherwise it'll
// cause an extra newline that might not last)
removeChildFromParent(
lastNode.parent.node as Paragraph,
lastNode.parent.parent?.node as Node,
);
} else {
const prev = prevSibling(lastNode);
if (prev?.type === "break") {
removeChildFromParent(
prev.node as RootContent,
lastNode.parent?.node as Node,
);
removeChildFromParent(
lastNode.node as RootContent,
lastNode.parent?.node as Node,
);
}
}
}
}
}
}
if (ruleMatched) {
return tree;
}
// we need to handle a case like
// - *def `abc` [abc **def**](abc)*
// which is pretty tricky, because if we land just after def, then we actually
// have two separate tags to process at two different parents. Maybe we
// need to keep iterating up until we find a paragraph, but process each
// parent on the way up. Hmm, well actually after `def` we won't even be a proper link yet
// TODO(drifkin): it's really if the last node's parent is a paragraph, for which the following is a sub-case where the lastNode is a text node.
// And instead of just processing simple text rules, they need to operate on the whole paragraph
// like `**[abc](def)` needs to become `**[abc](def)**`
// if we're just text at the end, then we should remove some ambiguous characters
if (lastNode.parent) {
const didChange = processParent(lastNode.parent as Parent & Proxy);
if (didChange) {
// TODO(drifkin): need to fix up the tree, but not sure lastNode will still exist? Check all the transforms to see if it's safe to find the last node again
//
// need to regen the tree w/ parents since reparenting could've happened
// treeWithParents = parents(tree);
}
}
const grandparent = lastNode.parent?.parent;
// TODO(drifkin): let's go arbitrarily high up the tree, but limiting it
// to 2 levels for now until I think more about the stop condition
if (grandparent) {
processParent(grandparent as Parent & Proxy);
}
// console.log("ruleMatched", ruleMatched);
// } else if (lastNode.parent?.type === "paragraph") {
// console.log("!!! paragraph");
// console.log("lastNode.parent", lastNode.parent);
// // Handle `**abc*` -> `**abc**`:
// // We detect this when the last child is an emphasis node, and it's preceded by a text node that ends with `*`
// const paragraph = lastNode.parent as Proxy & Paragraph;
// if (paragraph.children.length >= 2) {
// const lastChild = paragraph.children[paragraph.children.length - 1];
// if (lastChild.type === "emphasis") {
// const sibling = paragraph.children[paragraph.children.length - 2];
// if (sibling.type === "text") {
// const siblingText = sibling as Text & Proxy;
// if (siblingText.value.endsWith("*")) {
// ruleMatched = true;
// const textNode = (lastNode as Proxy).node as Text;
// textNode.value = textNode.value.slice(0, -1);
// paragraph.node.type = "strong";
// }
// }
// }
// }
// } else if (lastNode.type === "text") {
// // Handle `**abc*` -> `**abc**`:
// //
// // this gets parsed as a text node ending in `*` followed by an emphasis
// // node. So if we're in text, we need to check if our parent is emphasis,
// // and then get our parent's sibling before it and check if it ends with
// // `*`
// const parent = lastNode.parent;
// if (parent && parent.type === "emphasis") {
// const grandparent = parent.parent;
// if (grandparent) {
// const index = (grandparent.node as Parent).children.indexOf(
// parent.node as RootContent,
// );
// if (index > 0) {
// const prevNode = grandparent.children[index - 1];
// if (
// prevNode.type === "text" &&
// (prevNode as Text).value.endsWith("*")
// ) {
// ruleMatched = true;
// const textNode = (prevNode as Proxy).node as Text;
// textNode.value = textNode.value.slice(0, -1);
// parent.node.type = "strong";
// }
// }
// }
// }
// if (!ruleMatched) {
// // if the last node is just text, then we process it in order to fix up certain unclosed items
// // e.g., `**abc` -> `**abc**`
// const textNode = lastNode.node as Text;
// for (const rule of simpleTextRules) {
// const matchesIterator = textNode.value.matchAll(rule.pattern);
// const matches = [...matchesIterator];
// if (matches.length > 0) {
// rule.transform(matches, lastNode);
// ruleMatched = true;
// break;
// }
// }
// }
// } else if (!ruleMatched) {
// // console.log("no rule matched", lastNode);
// }
return tree;
};
};
function processParent(parent: Parent & Proxy): boolean {
if (parent.type === "emphasis") {
// Handle `**abc*` -> `**abc**`:
// We detect this when we end with an emphasis node, and it's preceded by
// a text node that ends with `*`
// TODO(drifkin): the last node can be more deeply nested (e.g., a code
// literal in a link), so we probably need to walk up the tree until we
// find an emphasis node or a block? For now we'll just go up one layer to
// catch the most common cases
const emphasisNode = parent as Emphasis & Proxy;
const grandparent = emphasisNode.parent;
if (grandparent) {
const indexOfEmphasisNode = (grandparent.node as Parent).children.indexOf(
emphasisNode.node as RootContent,
);
if (indexOfEmphasisNode >= 0) {
const nodeBefore = grandparent.children[indexOfEmphasisNode - 1] as
| (Node & Proxy)
| undefined;
if (nodeBefore?.type === "text") {
const textNode = nodeBefore.node as Text;
if (textNode.value.endsWith("*")) {
const strBefore = textNode.value.slice(0, -1);
textNode.value = strBefore;
const strongNode = u("strong", {
children: emphasisNode.children,
});
(grandparent.node as Parent).children.splice(
indexOfEmphasisNode,
1,
strongNode,
);
return true;
}
}
}
}
}
// Let's check if we have any bold items to close
for (let i = parent.children.length - 1; i >= 0; i--) {
const child = parent.children[i];
if (child.type === "text") {
const textNode = child as Text & Proxy;
const sep = "**";
const index = textNode.value.lastIndexOf(sep);
if (index >= 0) {
let isValidOpening = false;
if (index + sep.length < textNode.value.length) {
const charAfter = textNode.value[index + sep.length];
if (!isWhitespace(charAfter)) {
isValidOpening = true;
}
} else {
if (i < parent.children.length - 1) {
// TODO(drifkin): I'm not sure that this check is strict enough.
// We're trying to detect cases like `**[abc]()` where the char
// after the opening ** is indeed a non-whitespace character. We're
// using the heuristic that there's another item after the current
// one, but I'm not sure if that is good enough. In a well
// constructed tree, there aren't two text nodes in a row, so this
// _seems_ good, but I should think through it more
isValidOpening = true;
}
}
if (isValidOpening) {
// TODO(drifkin): close the bold
const strBefore = textNode.value.slice(0, index);
const strAfter = textNode.value.slice(index + sep.length);
(textNode.node as Text).value = strBefore;
// TODO(drifkin): the node above could be empty in which case we probably want to delete it
const children: PhrasingContent[] = [
...(strAfter.length > 0 ? [u("text", { value: strAfter })] : []),
];
const strongNode: Strong = u("strong", {
children,
});
const nodesAfter = (parent.node as Parent).children.splice(
i + 1,
parent.children.length - i - 1,
strongNode,
);
// TODO(drifkin): this cast seems iffy, should see if we can cast the
// parent instead, which would also help us check some of our
// assumptions
strongNode.children.push(...(nodesAfter as PhrasingContent[]));
return true;
}
}
}
}
return false;
}
function prevSibling(node: Node & Proxy): (Node & Proxy) | null {
const parent = node.parent;
if (parent) {
const index = parent.children.indexOf(node);
return (parent.children[index - 1] as Node & Proxy) ?? null;
}
return null;
}
function isWhitespace(str: string) {
return str.trim() === "";
}
// function debugPrintTreeNoPos(tree: Node) {
// console.log(
// JSON.stringify(
// tree,
// (key, value) => {
// if (key === "position") {
// return undefined;
// }
// return value;
// },
// 2,
// ),
// );
// }
export default remarkStreamingMarkdown;
import { describe, it, expect } from "vitest";
import { parseVRAM, getTotalVRAM } from "./vram";
describe("VRAM Utilities", () => {
describe("parseVRAM", () => {
it("should parse GB (decimal) values correctly", () => {
expect(parseVRAM("1 GB")).toBeCloseTo(1000 / 1024); // ≈0.9765625 GiB
expect(parseVRAM("16.5 GB")).toBeCloseTo(16.5 * (1000 / 1024));
expect(parseVRAM("32GB")).toBeCloseTo(32 * (1000 / 1024));
});
it("should parse GiB (binary) values correctly", () => {
expect(parseVRAM("8 GiB")).toBe(8);
expect(parseVRAM("12.8 GiB")).toBe(12.8);
expect(parseVRAM("24GiB")).toBe(24);
});
it("should convert MB (decimal) to GiB correctly", () => {
expect(parseVRAM("1000 MB")).toBeCloseTo(1000 / (1024 * 1024));
expect(parseVRAM("8192 MB")).toBeCloseTo(8192 / (1024 * 1024));
expect(parseVRAM("512.5 MB")).toBeCloseTo(512.5 / (1024 * 1024));
});
it("should convert MiB (binary) to GiB correctly", () => {
expect(parseVRAM("1024 MiB")).toBe(1);
expect(parseVRAM("2048MiB")).toBe(2);
expect(parseVRAM("512.5 MiB")).toBe(512.5 / 1024);
expect(parseVRAM("8192 MiB")).toBe(8);
});
it("should handle case insensitive units", () => {
expect(parseVRAM("8 gb")).toBeCloseTo(8 * (1000 / 1024));
expect(parseVRAM("8 Gb")).toBeCloseTo(8 * (1000 / 1024));
expect(parseVRAM("8 GiB")).toBe(8);
expect(parseVRAM("1024 mib")).toBe(1);
expect(parseVRAM("1000 mb")).toBeCloseTo(1000 / (1024 * 1024));
});
it("should return null for invalid inputs", () => {
expect(parseVRAM("")).toBeNull();
expect(parseVRAM("invalid")).toBeNull();
expect(parseVRAM("8 TB")).toBeNull();
expect(parseVRAM("GB 8")).toBeNull();
expect(parseVRAM("8")).toBeNull();
});
it("should handle edge cases", () => {
expect(parseVRAM("0 GB")).toBe(0);
expect(parseVRAM("0.1 GiB")).toBe(0.1);
expect(parseVRAM("999999 GiB")).toBe(999999);
});
});
describe("Integration tests", () => {
it("should parse various VRAM formats consistently", () => {
const testCases = ["8 GB", "16 GiB", "1024 MB", "512 MiB"];
testCases.forEach((testCase) => {
const parsed = parseVRAM(testCase);
expect(parsed).not.toBeNull();
expect(typeof parsed).toBe("number");
expect(parsed).toBeGreaterThan(0);
});
});
});
describe("getTotalVRAM", () => {
it("should sum VRAM values from multiple computes", () => {
const computes = [
{ vram: "8 GiB" },
{ vram: "16 GiB" },
{ vram: "4 GiB" },
];
expect(getTotalVRAM(computes)).toBe(28);
});
it("should handle MB and MiB conversions correctly", () => {
const computes = [
{ vram: "1024 MiB" }, // 1 GiB
{ vram: "2048 MiB" }, // 2 GiB
{ vram: "1000 MB" }, // ~0.000953 GiB
];
expect(getTotalVRAM(computes)).toBeCloseTo(3.000953);
});
it("should handle mixed units", () => {
const computes = [
{ vram: "8 GiB" },
{ vram: "1000 MB" }, // ~0.000953 GiB
{ vram: "16 GB" }, // 16 * 0.9765625 GiB
];
expect(getTotalVRAM(computes)).toBeCloseTo(8 + 0.000953 + 15.625);
});
it("should skip invalid VRAM strings", () => {
const computes = [
{ vram: "8 GiB" },
{ vram: "invalid" },
{ vram: "16 GiB" },
{ vram: "" },
];
expect(getTotalVRAM(computes)).toBe(24);
});
it("should handle empty array", () => {
expect(getTotalVRAM([])).toBe(0);
});
it("should handle single compute", () => {
const computes = [{ vram: "12 GiB" }];
expect(getTotalVRAM(computes)).toBe(12);
});
it("should handle decimal values", () => {
const computes = [{ vram: "8.5 GiB" }, { vram: "4.25 GiB" }];
expect(getTotalVRAM(computes)).toBe(12.75);
});
});
});
const GIB_FACTOR: Record<string, number> = {
gib: 1, // 1 GiB = 1 GiB
gb: 1000 / 1024, // 1 GB (decimal) = ~0.9765625 GiB
mib: 1 / 1024, // 1 MiB = 1/1024 GiB
mb: 1 / (1024 * 1024), // 1 MB (decimal), converted with a 1/(1024*1024) GiB factor (≈9.54e-7 GiB)
};
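// parseVRAM converts a VRAM string such as "8 GiB", "1024 MiB", or "16 GB" into
// GiB, returning null when the value or unit can't be parsed (e.g. "8" alone).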
export function parseVRAM(vramString: string): number | null {
if (!vramString) return null;
const match = vramString.match(/^(\d+(?:\.\d+)?)\s*(GiB|GB|MiB|MB)$/i);
if (!match) return null;
const value = parseFloat(match[1]);
const unit = match[2].toLowerCase();
return value * GIB_FACTOR[unit];
}
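// getTotalVRAM sums every parseable `vram` string and skips invalid ones,
// e.g. [{ vram: "8 GiB" }, { vram: "invalid" }, { vram: "16 GiB" }] totals 24.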
export function getTotalVRAM(inferenceComputes: { vram: string }[]): number {
let totalVRAM = 0;
for (const compute of inferenceComputes) {
const parsed = parseVRAM(compute.vram);
if (parsed !== null) {
totalVRAM += parsed;
}
}
return totalVRAM;
}
/// <reference types="vite/client" />
/** @type {import('tailwindcss').Config} */
export default {
content: ["./index.html", "./src/**/*.{js,ts,jsx,tsx}"],
theme: {
extend: {
spacing: {
3.5: "0.875rem",
4.5: "1.125rem",
},
colors: {
gray: {
350: "#a1a1aa",
},
},
},
},
plugins: [require("@tailwindcss/typography")],
};
{
"compilerOptions": {
"baseUrl": ".",
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
"target": "ES2020",
"useDefineForClassFields": true,
"lib": ["ES2020", "DOM", "DOM.Iterable"],
"module": "ESNext",
"skipLibCheck": true,
"types": [],
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"verbatimModuleSyntax": true,
"moduleDetection": "force",
"noEmit": true,
"jsx": "react-jsx",
/* Linting */
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true,
"noUncheckedSideEffectImports": true,
"paths": {
"@/gotypes": ["codegen/gotypes.gen.ts"],
"@/*": ["src/*"]
}
},
"include": ["src", "codegen"],
"exclude": ["src/**/*.test.*", "src/**/*.stories.*"]
}
{
"files": [],
"references": [
{ "path": "./tsconfig.app.json" },
{ "path": "./tsconfig.node.json" },
{ "path": "./tsconfig.stories.json" }
]
}
{
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
"target": "ES2022",
"lib": ["ES2023"],
"module": "ESNext",
"skipLibCheck": true,
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"verbatimModuleSyntax": true,
"moduleDetection": "force",
"noEmit": true,
/* Linting */
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true,
"noUncheckedSideEffectImports": true
},
"include": ["vite.config.ts"]
}
{
"extends": "./tsconfig.app.json",
"compilerOptions": {
"types": ["@storybook/react-vite"]
},
"include": ["src/**/*.stories.*"],
"exclude": []
}
import { defineConfig } from "vite";
import react from "@vitejs/plugin-react";
import { TanStackRouterVite } from "@tanstack/router-plugin/vite";
import tailwindcss from "@tailwindcss/vite";
import tsconfigPaths from "vite-tsconfig-paths";
import postcssPresetEnv from "postcss-preset-env";
import { resolve } from "path";
export default defineConfig(() => ({
base: "/",
plugins: [
TanStackRouterVite({ target: "react" }),
react(),
tailwindcss(),
tsconfigPaths(),
],
resolve: {
alias: {
"@/gotypes": resolve(__dirname, "codegen/gotypes.gen.ts"),
"@": resolve(__dirname, "src"),
"micromark-extension-math": "micromark-extension-llm-math",
},
},
css: {
postcss: {
plugins: [
postcssPresetEnv({
stage: 1, // Include more experimental features that Safari 14 needs
browsers: ["Safari >= 14"],
// autoprefixer: false,
features: {
"custom-properties": true, // Let TailwindCSS handle this
"nesting-rules": true,
"logical-properties-and-values": true, // Polyfill logical properties
"media-query-ranges": true, // Modern media query syntax
"color-function": true, // CSS color functions
"double-position-gradients": true,
"gap-properties": true, // This is key for flexbox gap!
"place-properties": true,
"overflow-property": true,
"focus-visible-pseudo-class": true, // Focus-visible support
"focus-within-pseudo-class": true, // Focus-within support
"any-link-pseudo-class": true, // :any-link pseudo-class
"not-pseudo-class": true, // Enhanced :not() support
"dir-pseudo-class": true, // :dir() pseudo-class
"all-property": true, // CSS 'all' property
"image-set-function": true, // image-set() function
"hwb-function": true, // hwb() color function
"lab-function": true, // lab() color function
"oklab-function": true, // oklab() color function
},
}),
],
},
},
build: {
target: "es2017",
},
esbuild: {
target: "es2017",
},
}));
import { defineConfig, mergeConfig } from "vite";
import path from "path";
import baseConfig from "./vite.config";
export default defineConfig((configEnv) =>
mergeConfig(
baseConfig(configEnv),
defineConfig({
resolve: {
alias: {
"@": path.resolve(__dirname, "./src"),
"@/gotypes": path.resolve(__dirname, "./codegen/gotypes.gen.ts"),
},
},
test: {
environment: "node",
globals: true,
},
}),
),
);
/// <reference types="@vitest/browser/providers/playwright" />
//go:build windows || darwin
package ui
import (
"bytes"
"fmt"
"path/filepath"
"slices"
"strings"
"unicode/utf8"
"github.com/ledongthuc/pdf"
)
// convertBytesToText converts raw file bytes to text based on file extension
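// For example, ".pdf" data is routed through extractPDFText, a ".png" upload
// yields a "[Binary file of type .png - N bytes]" placeholder, and valid UTF-8
// text is returned as-is.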
func convertBytesToText(data []byte, filename string) string {
ext := strings.ToLower(filepath.Ext(filename))
if ext == ".pdf" {
text, err := extractPDFText(data)
if err != nil {
return fmt.Sprintf("[PDF file - %d bytes - failed to extract text: %v]", len(data), err)
}
if strings.TrimSpace(text) == "" {
return fmt.Sprintf("[PDF file - %d bytes - no text content found]", len(data))
}
return text
}
binaryExtensions := []string{
".xlsx", ".pptx", ".zip", ".tar", ".gz", ".rar",
".jpg", ".jpeg", ".png", ".gif", ".bmp", ".svg", ".ico",
".mp3", ".mp4", ".avi", ".mov", ".wmv", ".flv", ".webm",
".exe", ".dll", ".so", ".dylib", ".app", ".dmg", ".pkg",
}
if slices.Contains(binaryExtensions, ext) {
return fmt.Sprintf("[Binary file of type %s - %d bytes]", ext, len(data))
}
if utf8.Valid(data) {
return string(data)
}
// If not valid UTF-8, return a placeholder
return fmt.Sprintf("[Binary file - %d bytes - not valid UTF-8]", len(data))
}
// extractPDFText extracts text content from PDF bytes
func extractPDFText(data []byte) (string, error) {
reader := bytes.NewReader(data)
pdfReader, err := pdf.NewReader(reader, int64(len(data)))
if err != nil {
return "", fmt.Errorf("failed to create PDF reader: %w", err)
}
var textBuilder strings.Builder
numPages := pdfReader.NumPage()
for i := 1; i <= numPages; i++ {
page := pdfReader.Page(i)
if page.V.IsNull() {
continue
}
text, err := page.GetPlainText(nil)
if err != nil {
// Skip this page and continue with the others
continue
}
if strings.TrimSpace(text) != "" {
if textBuilder.Len() > 0 {
textBuilder.WriteString("\n\n--- Page ")
textBuilder.WriteString(fmt.Sprintf("%d", i))
textBuilder.WriteString(" ---\n")
}
textBuilder.WriteString(text)
}
}
return textBuilder.String(), nil
}
//go:build windows || darwin
package responses
import (
"time"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/types/model"
)
type ChatInfo struct {
ID string `json:"id"`
Title string `json:"title"`
UserExcerpt string `json:"userExcerpt"`
CreatedAt time.Time `json:"createdAt" ts_type:"Date" ts_transform:"new Date(__VALUE__)"`
UpdatedAt time.Time `json:"updatedAt" ts_type:"Date" ts_transform:"new Date(__VALUE__)"`
}
type ChatsResponse struct {
ChatInfos []ChatInfo `json:"chatInfos"`
}
type ChatResponse struct {
Chat store.Chat `json:"chat"`
}
type Model struct {
Model string `json:"model"`
Digest string `json:"digest,omitempty"`
ModifiedAt *time.Time `json:"modified_at,omitempty"`
}
type ModelsResponse struct {
Models []Model `json:"models"`
}
type InferenceCompute struct {
Library string `json:"library"`
Variant string `json:"variant"`
Compute string `json:"compute"`
Driver string `json:"driver"`
Name string `json:"name"`
VRAM string `json:"vram"`
}
type InferenceComputeResponse struct {
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
}
type ModelCapabilitiesResponse struct {
Capabilities []model.Capability `json:"capabilities"`
}
// ChatEvent is for regular chat messages and assistant interactions
type ChatEvent struct {
EventName string `json:"eventName" ts_type:"\"chat\" | \"thinking\" | \"assistant_with_tools\" | \"tool_call\" | \"tool\" | \"tool_result\" | \"done\" | \"chat_created\""`
// Chat/Assistant message fields
Content *string `json:"content,omitempty"`
Thinking *string `json:"thinking,omitempty"`
ThinkingTimeStart *time.Time `json:"thinkingTimeStart,omitempty" ts_type:"Date | undefined" ts_transform:"__VALUE__ && new Date(__VALUE__)"`
ThinkingTimeEnd *time.Time `json:"thinkingTimeEnd,omitempty" ts_type:"Date | undefined" ts_transform:"__VALUE__ && new Date(__VALUE__)"`
// Tool-related fields
ToolCalls []store.ToolCall `json:"toolCalls,omitempty"`
ToolCall *store.ToolCall `json:"toolCall,omitempty"`
ToolName *string `json:"toolName,omitempty"`
ToolResult *bool `json:"toolResult,omitempty"`
ToolResultData any `json:"toolResultData,omitempty"`
// Chat creation fields
ChatID *string `json:"chatId,omitempty"`
// Tool state attached to tool results (e.g. browser state)
ToolState any `json:"toolState,omitempty"`
}
// DownloadEvent is for model download progress
type DownloadEvent struct {
EventName string `json:"eventName" ts_type:"\"download\""`
Total int64 `json:"total" ts_type:"number"`
Completed int64 `json:"completed" ts_type:"number"`
Done bool `json:"done" ts_type:"boolean"`
}
// ErrorEvent is for error messages
type ErrorEvent struct {
EventName string `json:"eventName" ts_type:"\"error\""`
Error string `json:"error"`
Code string `json:"code,omitempty"` // Optional error code for different error types
Details string `json:"details,omitempty"` // Optional additional details
}
type SettingsResponse struct {
Settings store.Settings `json:"settings"`
}
type HealthResponse struct {
Healthy bool `json:"healthy"`
}
type User struct {
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatarURL"`
Plan string `json:"plan"`
Bio string `json:"bio"`
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
OverThreshold bool `json:"overThreshold"`
}
type Attachment struct {
Filename string `json:"filename"`
Data string `json:"data,omitempty"` // omitempty = optional, no data = existing file reference
}
type ChatRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Index *int `json:"index,omitempty"`
Attachments []Attachment `json:"attachments,omitempty"`
WebSearch *bool `json:"web_search,omitempty"`
FileTools *bool `json:"file_tools,omitempty"`
ForceUpdate bool `json:"forceUpdate,omitempty"`
Think any `json:"think,omitempty"`
}
type Error struct {
Error string `json:"error"`
}
type ModelUpstreamResponse struct {
Digest string `json:"digest,omitempty"`
PushTime int64 `json:"pushTime"`
Error string `json:"error,omitempty"`
}
// Serializable data for the browser state
type BrowserStateData struct {
PageStack []string `json:"page_stack"` // Sequential list of page URLs
ViewTokens int `json:"view_tokens"` // Number of tokens to show in viewport
URLToPage map[string]*Page `json:"url_to_page"` // URL to page contents
}
// Page represents the contents of a page
type Page struct {
URL string `json:"url"`
Title string `json:"title"`
Text string `json:"text"`
Lines []string `json:"lines"`
Links map[int]string `json:"links,omitempty" ts_type:"Record<number, string>"`
FetchedAt time.Time `json:"fetched_at"`
}
//go:build windows || darwin
// Package ui implements a chat interface for Ollama.
package ui
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"os"
"runtime"
"runtime/debug"
"slices"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/auth"
"github.com/ollama/ollama/app/server"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/tools"
"github.com/ollama/ollama/app/types/not"
"github.com/ollama/ollama/app/ui/responses"
"github.com/ollama/ollama/app/version"
ollamaAuth "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
_ "github.com/tkrajina/typescriptify-golang-structs/typescriptify"
)
//go:generate tscriptify -package=github.com/ollama/ollama/app/ui/responses -target=./app/codegen/gotypes.gen.ts responses/types.go
//go:generate npm --prefix ./app run build
var CORS = envconfig.Bool("OLLAMA_CORS")
// OllamaDotCom is the base URL for ollama.com; it can be overridden with the OLLAMA_DOT_COM_URL environment variable
var OllamaDotCom = func() string {
if url := os.Getenv("OLLAMA_DOT_COM_URL"); url != "" {
return url
}
return "https://ollama.com"
}()
type statusRecorder struct {
http.ResponseWriter
code int
}
func (r *statusRecorder) Written() bool {
return r.code != 0
}
func (r *statusRecorder) WriteHeader(code int) {
r.code = code
r.ResponseWriter.WriteHeader(code)
}
func (r *statusRecorder) Status() int {
if r.code == 0 {
return http.StatusOK
}
return r.code
}
func (r *statusRecorder) Flush() {
if flusher, ok := r.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
// Event is a string that represents the type of event being sent to the
// client. It is used in the Server-Sent Events (SSE) protocol to identify
// the type of data being sent.
// The client (template) will use this type in the sse event listener to
// determine how to handle the incoming data. It will also be used in the
// sse-swap htmx event listener to determine how to handle the incoming data.
type Event string
const (
EventChat Event = "chat"
EventComplete Event = "complete"
EventLoading Event = "loading"
EventToolResult Event = "tool_result" // Used for both tool calls and their results
EventThinking Event = "thinking"
EventToolCall Event = "tool_call"
EventDownload Event = "download"
)
type Server struct {
Logger *slog.Logger
Restart func()
Token string
Store *store.Store
ToolRegistry *tools.Registry
Tools bool // if true, the server will use single-turn tools to fulfill the user's request
WebSearch bool // if true, the server will use single-turn browser tool to fulfill the user's request
Agent bool // if true, the server will use multi-turn tools to fulfill the user's request
WorkingDir string // Working directory for all agent operations
// Dev is true if the server is running in development mode
Dev bool
}
func (s *Server) log() *slog.Logger {
if s.Logger == nil {
return slog.Default()
}
return s.Logger
}
// ollamaProxy creates a reverse proxy handler to the Ollama server
func (s *Server) ollamaProxy() http.Handler {
ollamaHost := os.Getenv("OLLAMA_HOST")
if ollamaHost == "" {
ollamaHost = "http://127.0.0.1:11434"
}
if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
ollamaHost = "http://" + ollamaHost
}
target, err := url.Parse(ollamaHost)
if err != nil {
s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
})
}
s.log().Info("configuring ollama proxy", "target", target.String())
proxy := httputil.NewSingleHostReverseProxy(target)
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = target.Host
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
}
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
}
return proxy
}
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
func (s *Server) Handler() http.Handler {
handle := func(f errHandlerFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add CORS headers for dev work
if CORS() {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
w.Header().Set("Access-Control-Allow-Credentials", "true")
// Handle preflight requests
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
}
// Don't check for token in development mode
if !s.Dev {
cookie, err := r.Cookie("token")
if err != nil {
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{"error": "Token is required"})
return
}
if cookie.Value != s.Token {
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{"error": "Token is required"})
return
}
}
sw := &statusRecorder{ResponseWriter: w}
log := s.log()
level := slog.LevelInfo
start := time.Now()
requestID := fmt.Sprintf("%d", time.Now().UnixNano())
defer func() {
p := recover()
if p != nil {
log = log.With("panic", p, "request_id", requestID)
level = slog.LevelError
// Handle panic with user-friendly error
if !sw.Written() {
s.handleError(sw, fmt.Errorf("internal server error"))
}
}
log.Log(r.Context(), level, "site.serveHTTP",
"http.method", r.Method,
"http.path", r.URL.Path,
"http.pattern", r.Pattern,
"http.status", sw.Status(),
"http.d", time.Since(start),
"request_id", requestID,
"version", version.Version,
)
// let net/http.Server deal with panics
if p != nil {
panic(p)
}
}()
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-Version", version.Version)
w.Header().Set("X-Request-ID", requestID)
ctx := r.Context()
if err := f(sw, r); err != nil {
if ctx.Err() != nil {
return
}
level = slog.LevelError
log = log.With("error", err)
s.handleError(sw, err)
}
})
}
mux := http.NewServeMux()
// CORS is handled in `handle`, but we have to match on OPTIONS to handle preflight requests
mux.Handle("OPTIONS /", handle(func(w http.ResponseWriter, r *http.Request) error {
return nil
}))
// API routes - handle first to take precedence
mux.Handle("GET /api/v1/chats", handle(s.listChats))
mux.Handle("GET /api/v1/chat/{id}", handle(s.getChat))
mux.Handle("POST /api/v1/chat/{id}", handle(s.chat))
mux.Handle("DELETE /api/v1/chat/{id}", handle(s.deleteChat))
mux.Handle("POST /api/v1/create-chat", handle(s.createChat))
mux.Handle("PUT /api/v1/chat/{id}/rename", handle(s.renameChat))
mux.Handle("GET /api/v1/inference-compute", handle(s.getInferenceCompute))
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
mux.Handle("POST /api/v1/settings", handle(s.settings))
// Ollama proxy endpoints
ollamaProxy := s.ollamaProxy()
mux.Handle("GET /api/tags", ollamaProxy)
mux.Handle("POST /api/show", ollamaProxy)
mux.Handle("GET /api/v1/me", handle(s.me))
mux.Handle("POST /api/v1/disconnect", handle(s.disconnect))
mux.Handle("GET /api/v1/connect", handle(s.connectURL))
mux.Handle("GET /api/v1/health", handle(s.health))
// React app - catch all non-API routes and serve the React app
mux.Handle("GET /", s.appHandler())
mux.Handle("PUT /", s.appHandler())
mux.Handle("POST /", s.appHandler())
mux.Handle("PATCH /", s.appHandler())
mux.Handle("DELETE /", s.appHandler())
return mux
}
// handleError renders appropriate error responses based on request type
func (s *Server) handleError(w http.ResponseWriter, e error) {
// Preserve CORS headers for API requests
if CORS() {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": e.Error()})
}
// userAgentTransport is a custom RoundTripper that adds the User-Agent header to all requests
type userAgentTransport struct {
base http.RoundTripper
}
func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Clone the request to avoid mutating the original
r := req.Clone(req.Context())
r.Header.Set("User-Agent", userAgent())
return t.base.RoundTrip(r)
}
// httpClient returns an HTTP client that automatically adds the User-Agent header
func (s *Server) httpClient() *http.Client {
return &http.Client{
Timeout: 10 * time.Second,
Transport: &userAgentTransport{
base: http.DefaultTransport,
},
}
}
// doSelfSigned sends a self-signed request to the ollama.com API
func (s *Server) doSelfSigned(ctx context.Context, method, path string) (*http.Response, error) {
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
// Form the string to sign: METHOD,PATH?ts=TIMESTAMP
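// e.g. "POST,/api/me?ts=1712345678"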
signString := fmt.Sprintf("%s,%s?ts=%s", method, path, timestamp)
signature, err := ollamaAuth.Sign(ctx, []byte(signString))
if err != nil {
return nil, fmt.Errorf("failed to sign request: %w", err)
}
endpoint := fmt.Sprintf("%s%s?ts=%s", OllamaDotCom, path, timestamp)
req, err := http.NewRequestWithContext(ctx, method, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
return s.httpClient().Do(req)
}
// UserData fetches user data from ollama.com API for the current ollama key
func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me")
if err != nil {
return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var user responses.User
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return nil, fmt.Errorf("failed to parse user response: %w", err)
}
user.AvatarURL = fmt.Sprintf("%s/%s", OllamaDotCom, user.AvatarURL)
storeUser := store.User{
Name: user.Name,
Email: user.Email,
Plan: user.Plan,
}
if err := s.Store.SetUser(storeUser); err != nil {
s.log().Warn("failed to cache user data", "error", err)
}
return &user, nil
}
func waitForServer(ctx context.Context) error {
timeout := time.Now().Add(10 * time.Second)
// TODO: this avoids an error on first load of the app
// however we should either show a loading state or
// wait for the Ollama server to be ready before redirecting
for {
c, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if _, err := c.Version(ctx); err == nil {
break
}
if time.Now().After(timeout) {
return fmt.Errorf("timeout waiting for Ollama server to be ready")
}
time.Sleep(10 * time.Millisecond)
}
return nil
}
func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error {
waitForServer(r.Context())
id, err := uuid.NewV7()
if err != nil {
return fmt.Errorf("failed to generate chat ID: %w", err)
}
json.NewEncoder(w).Encode(map[string]string{"id": id.String()})
return nil
}
func (s *Server) listChats(w http.ResponseWriter, r *http.Request) error {
chats, _ := s.Store.Chats()
chatInfos := make([]responses.ChatInfo, len(chats))
for i, chat := range chats {
chatInfos[i] = chatInfoFromChat(chat)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(responses.ChatsResponse{ChatInfos: chatInfos})
return nil
}
// checkModelUpstream makes a HEAD request to the Ollama registry to get the upstream digest and push time
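// e.g. a model named "foo:1b" becomes HEAD <OllamaDotCom>/v2/library/foo/manifests/1b;
// a name without a tag defaults to "latest".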
func (s *Server) checkModelUpstream(ctx context.Context, modelName string, timeout time.Duration) (string, int64, error) {
// Create a context with timeout for the registry check
checkCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Parse model name to get namespace, model, and tag
parts := strings.Split(modelName, ":")
name := parts[0]
tag := "latest"
if len(parts) > 1 {
tag = parts[1]
}
if !strings.Contains(name, "/") {
// If the model name does not contain a slash, assume it's a library model
name = "library/" + name
}
// Check the model in the Ollama registry using HEAD request
url := OllamaDotCom + "/v2/" + name + "/manifests/" + tag
req, err := http.NewRequestWithContext(checkCtx, "HEAD", url, nil)
if err != nil {
return "", 0, err
}
httpClient := s.httpClient()
httpClient.Timeout = timeout
resp, err := httpClient.Do(req)
if err != nil {
return "", 0, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", 0, fmt.Errorf("registry returned status %d", resp.StatusCode)
}
digest := resp.Header.Get("ollama-content-digest")
if digest == "" {
return "", 0, fmt.Errorf("no digest header found")
}
var pushTime int64
if pushTimeStr := resp.Header.Get("ollama-push-time"); pushTimeStr != "" {
if pt, err := strconv.ParseInt(pushTimeStr, 10, 64); err == nil {
pushTime = pt
}
}
return digest, pushTime, nil
}
// isNetworkError checks if an error string contains common network/connection error patterns
func isNetworkError(errStr string) bool {
networkErrorPatterns := []string{
"connection refused",
"no such host",
"timeout",
"network is unreachable",
"connection reset",
"connection timed out",
"temporary failure",
"dial tcp",
"i/o timeout",
"context deadline exceeded",
"broken pipe",
}
for _, pattern := range networkErrorPatterns {
if strings.Contains(errStr, pattern) {
return true
}
}
return false
}
var ErrNetworkOffline = errors.New("network is offline")
func (s *Server) getError(err error) responses.ErrorEvent {
var sErr api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
return responses.ErrorEvent{
EventName: "error",
Error: "Could not verify you are signed in. Please sign in and try again.",
Code: "cloud_unauthorized",
}
}
errStr := err.Error()
switch {
case strings.Contains(errStr, "402"):
return responses.ErrorEvent{
EventName: "error",
Error: "You've reached your usage limit, please upgrade to continue",
Code: "usage_limit_upgrade",
}
case strings.HasPrefix(errStr, "pull model manifest") && isNetworkError(errStr):
return responses.ErrorEvent{
EventName: "error",
Error: "Unable to download model. Please check your internet connection to download the model for offline use.",
Code: "offline_download_error",
}
case errors.Is(err, ErrNetworkOffline) || strings.Contains(errStr, "operation timed out"):
return responses.ErrorEvent{
EventName: "error",
Error: "Connection lost",
Code: "turbo_connection_lost",
}
}
return responses.ErrorEvent{
EventName: "error",
Error: err.Error(),
}
}
func (s *Server) browserState(chat *store.Chat) (*responses.BrowserStateData, bool) {
if len(chat.BrowserState) > 0 {
var st responses.BrowserStateData
if err := json.Unmarshal(chat.BrowserState, &st); err == nil {
return &st, true
}
}
return nil, false
}
// reconstructBrowserState (legacy): return the latest full browser state stored in messages.
func reconstructBrowserState(messages []store.Message, defaultViewTokens int) *responses.BrowserStateData {
for i := len(messages) - 1; i >= 0; i-- {
msg := messages[i]
if msg.ToolResult == nil {
continue
}
var st responses.BrowserStateData
if err := json.Unmarshal(*msg.ToolResult, &st); err == nil {
if len(st.PageStack) > 0 || len(st.URLToPage) > 0 {
if st.ViewTokens == 0 {
st.ViewTokens = defaultViewTokens
}
return &st
}
}
}
return nil
}
func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/jsonl")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
return errors.New("streaming not supported")
}
if r.Method != "POST" {
return not.Found
}
cid := r.PathValue("id")
createdChat := false
// if cid is the literal string "new", then we create a new chat before
// performing our normal actions
if cid == "new" {
u, err := uuid.NewV7()
if err != nil {
return fmt.Errorf("failed to generate new chat id: %w", err)
}
cid = u.String()
createdChat = true
}
var req responses.ChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
fmt.Fprintf(os.Stderr, "error unmarshalling body: %v\n", err)
return fmt.Errorf("invalid request body: %w", err)
}
if req.Model == "" {
return fmt.Errorf("empty model")
}
// Don't allow empty messages unless forceUpdate is true
if req.Prompt == "" && !req.ForceUpdate {
return fmt.Errorf("empty message")
}
if createdChat {
// send message to the client that the chat has been created
json.NewEncoder(w).Encode(responses.ChatEvent{
EventName: "chat_created",
ChatID: &cid,
})
flusher.Flush()
}
// Check if this is from a specific message index (e.g. for editing)
idx := -1
if req.Index != nil {
idx = *req.Index
}
// Load chat with attachments since we need them for processing
chat, err := s.Store.ChatWithOptions(cid, true)
if err != nil {
if !errors.Is(err, not.Found) {
return err
}
chat = store.NewChat(cid)
}
// Only add user message if not forceUpdate
if !req.ForceUpdate {
var messageOptions *store.MessageOptions
if len(req.Attachments) > 0 {
storeAttachments := make([]store.File, 0, len(req.Attachments))
for _, att := range req.Attachments {
if att.Data == "" {
// This is an existing file reference - keep it from the original message
if idx >= 0 && idx < len(chat.Messages) {
originalMessage := chat.Messages[idx]
// Find the file by filename in the original message
for _, originalFile := range originalMessage.Attachments {
if originalFile.Filename == att.Filename {
storeAttachments = append(storeAttachments, originalFile)
break
}
}
}
} else {
// This is a new file - decode base64 data
data, err := base64.StdEncoding.DecodeString(att.Data)
if err != nil {
s.log().Error("failed to decode attachment data", "error", err, "filename", att.Filename)
continue
}
storeAttachments = append(storeAttachments, store.File{
Filename: att.Filename,
Data: data,
})
}
}
messageOptions = &store.MessageOptions{
Attachments: storeAttachments,
}
}
userMsg := store.NewMessage("user", req.Prompt, messageOptions)
if idx >= 0 && idx < len(chat.Messages) {
// Generate from specified message: truncate and replace
chat.Messages = chat.Messages[:idx]
chat.Messages = append(chat.Messages, userMsg)
} else {
// Normal mode: append new message
chat.Messages = append(chat.Messages, userMsg)
}
if err := s.Store.SetChat(*chat); err != nil {
return err
}
}
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
_, cancelLoading := context.WithCancel(ctx)
loading := false
c, err := api.ClientFromEnvironment()
if err != nil {
cancelLoading()
return err
}
// Check if the model exists locally by trying to show it
// TODO (jmorganca): skip this round trip and instead just act
// on a 404 error on chat
_, err = c.Show(ctx, &api.ShowRequest{Model: req.Model})
if err != nil || req.ForceUpdate {
// Create an empty assistant message to store the model information
// This will be overwritten when the model responds
chat.Messages = append(chat.Messages, store.NewMessage("assistant", "", &store.MessageOptions{Model: req.Model}))
if err := s.Store.SetChat(*chat); err != nil {
cancelLoading()
return err
}
// Send download progress events while the model is being pulled
// TODO (jmorganca): this only shows the largest digest, but we
// should show the progress for the total size of the download
var largestDigest string
var largestTotal int64
err = c.Pull(ctx, &api.PullRequest{Model: req.Model}, func(progress api.ProgressResponse) error {
if progress.Digest != "" && progress.Total > largestTotal {
largestDigest = progress.Digest
largestTotal = progress.Total
}
if progress.Digest != "" && progress.Digest == largestDigest {
progressEvent := responses.DownloadEvent{
EventName: string(EventDownload),
Total: progress.Total,
Completed: progress.Completed,
Done: false,
}
if err := json.NewEncoder(w).Encode(progressEvent); err != nil {
return err
}
flusher.Flush()
}
return nil
})
if err != nil {
s.log().Error("model download error", "error", err, "model", req.Model)
errorEvent := s.getError(err)
json.NewEncoder(w).Encode(errorEvent)
flusher.Flush()
cancelLoading()
return fmt.Errorf("failed to download model: %w", err)
}
if err := json.NewEncoder(w).Encode(responses.DownloadEvent{
EventName: string(EventDownload),
Completed: largestTotal,
Total: largestTotal,
Done: true,
}); err != nil {
cancelLoading()
return err
}
flusher.Flush()
// If forceUpdate, we're done after updating the model
if req.ForceUpdate {
json.NewEncoder(w).Encode(responses.ChatEvent{EventName: "done"})
flusher.Flush()
cancelLoading()
return nil
}
}
loading = true
defer cancelLoading()
// Check the model capabilities
details, err := c.Show(ctx, &api.ShowRequest{Model: req.Model})
if err != nil || details == nil {
errorEvent := s.getError(err)
json.NewEncoder(w).Encode(errorEvent)
flusher.Flush()
s.log().Error("failed to show model details", "error", err, "model", req.Model)
return nil
}
think := slices.Contains(details.Capabilities, model.CapabilityThinking)
var thinkValue any
if req.Think != nil {
thinkValue = req.Think
} else {
thinkValue = think
}
// Check if the last user message has attachments
// TODO (parthsareen): this logic will change with directory drag and drop
hasAttachments := false
if len(chat.Messages) > 0 {
lastMsg := chat.Messages[len(chat.Messages)-1]
if lastMsg.Role == "user" && len(lastMsg.Attachments) > 0 {
hasAttachments = true
}
}
// Check if agent or tools mode is enabled
// Note: Skip agent/tools mode if the user has attachments, as the agent doesn't handle file attachments properly
registry := tools.NewRegistry()
var browser *tools.Browser
if !hasAttachments {
WebSearchEnabled := req.WebSearch != nil && *req.WebSearch
if WebSearchEnabled {
if supportsBrowserTools(req.Model) {
browserState, ok := s.browserState(chat)
if !ok {
browserState = reconstructBrowserState(chat.Messages, tools.DefaultViewTokens)
}
browser = tools.NewBrowser(browserState)
registry.Register(tools.NewBrowserSearch(browser))
registry.Register(tools.NewBrowserOpen(browser))
registry.Register(tools.NewBrowserFind(browser))
} else if supportsWebSearchTools(req.Model) {
registry.Register(&tools.WebSearch{})
registry.Register(&tools.WebFetch{})
}
}
}
var thinkingTimeStart *time.Time = nil
var thinkingTimeEnd *time.Time = nil
// Request-only assistant tool_calls buffer
// if tool_calls arrive before any assistant text, we keep them here,
// inject them into the next request, and attach on first assistant content/thinking.
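// e.g. a lone web_search call that arrives before any assistant text is kept
// here, replayed as a synthetic assistant message when the next request is
// built, and attached to the real assistant message once it starts streaming.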
var pendingAssistantToolCalls []store.ToolCall
passNum := 1
for {
var toolsExecuted bool
availableTools := registry.AvailableTools()
// If we have pending assistant tool_calls and no assistant yet,
// build the request against a temporary chat that includes a
// request-only assistant with tool_calls inserted BEFORE tool messages
reqChat := chat
if len(pendingAssistantToolCalls) > 0 {
if len(chat.Messages) == 0 || chat.Messages[len(chat.Messages)-1].Role != "assistant" {
temp := *chat
synth := store.NewMessage("assistant", "", &store.MessageOptions{Model: req.Model, ToolCalls: pendingAssistantToolCalls})
insertIdx := len(temp.Messages) - 1
for insertIdx >= 0 && temp.Messages[insertIdx].Role == "tool" {
insertIdx--
}
if insertIdx < 0 {
temp.Messages = append([]store.Message{synth}, temp.Messages...)
} else {
tmp := make([]store.Message, 0, len(temp.Messages)+1)
tmp = append(tmp, temp.Messages[:insertIdx+1]...)
tmp = append(tmp, synth)
tmp = append(tmp, temp.Messages[insertIdx+1:]...)
temp.Messages = tmp
}
reqChat = &temp
}
}
chatReq, err := s.buildChatRequest(reqChat, req.Model, thinkValue, availableTools)
if err != nil {
return err
}
err = c.Chat(ctx, chatReq, func(res api.ChatResponse) error {
if loading {
// Remove the loading indicator on first token
cancelLoading()
loading = false
}
// Start thinking timer on first thinking content or after tool call when thinking again
if res.Message.Thinking != "" && (thinkingTimeStart == nil || thinkingTimeEnd != nil) {
now := time.Now()
thinkingTimeStart = &now
thinkingTimeEnd = nil
}
if res.Message.Content == "" && res.Message.Thinking == "" && len(res.Message.ToolCalls) == 0 {
return nil
}
event := EventChat
if thinkingTimeStart != nil && res.Message.Content == "" && len(res.Message.ToolCalls) == 0 {
event = EventThinking
}
if len(res.Message.ToolCalls) > 0 {
event = EventToolCall
}
if event == EventToolCall && thinkingTimeStart != nil && thinkingTimeEnd == nil {
now := time.Now()
thinkingTimeEnd = &now
}
if event == EventChat && thinkingTimeStart != nil && thinkingTimeEnd == nil && res.Message.Content != "" {
now := time.Now()
thinkingTimeEnd = &now
}
json.NewEncoder(w).Encode(chatEventFromApiChatResponse(res, thinkingTimeStart, thinkingTimeEnd))
flusher.Flush()
switch event {
case EventToolCall:
if thinkingTimeEnd != nil {
if len(chat.Messages) > 0 && chat.Messages[len(chat.Messages)-1].Role == "assistant" {
lastMsg := &chat.Messages[len(chat.Messages)-1]
lastMsg.ThinkingTimeEnd = thinkingTimeEnd
lastMsg.UpdatedAt = time.Now()
s.Store.UpdateLastMessage(chat.ID, *lastMsg)
}
thinkingTimeStart = nil
thinkingTimeEnd = nil
}
// attach tool_calls to an existing assistant if present,
// otherwise (for standalone web_search/web_fetch) buffer for request-only injection.
if len(res.Message.ToolCalls) > 0 {
if len(chat.Messages) > 0 && chat.Messages[len(chat.Messages)-1].Role == "assistant" {
toolCalls := make([]store.ToolCall, len(res.Message.ToolCalls))
for i, tc := range res.Message.ToolCalls {
argsJSON, _ := json.Marshal(tc.Function.Arguments)
toolCalls[i] = store.ToolCall{
Type: "function",
Function: store.ToolFunction{
Name: tc.Function.Name,
Arguments: string(argsJSON),
},
}
}
lastMsg := &chat.Messages[len(chat.Messages)-1]
lastMsg.ToolCalls = toolCalls
if err := s.Store.UpdateLastMessage(chat.ID, *lastMsg); err != nil {
return err
}
} else {
onlyStandalone := true
for _, tc := range res.Message.ToolCalls {
if !(tc.Function.Name == "web_search" || tc.Function.Name == "web_fetch") {
onlyStandalone = false
break
}
}
if onlyStandalone {
toolCalls := make([]store.ToolCall, len(res.Message.ToolCalls))
for i, tc := range res.Message.ToolCalls {
argsJSON, _ := json.Marshal(tc.Function.Arguments)
toolCalls[i] = store.ToolCall{
Type: "function",
Function: store.ToolFunction{
Name: tc.Function.Name,
Arguments: string(argsJSON),
},
}
}
synth := store.NewMessage("assistant", "", &store.MessageOptions{Model: req.Model, ToolCalls: toolCalls})
chat.Messages = append(chat.Messages, synth)
if err := s.Store.AppendMessage(chat.ID, synth); err != nil {
return err
}
// clear the buffer to avoid injecting it again
pendingAssistantToolCalls = nil
}
}
}
for _, toolCall := range res.Message.ToolCalls {
// mark that tools were executed so the outer loop runs another pass
toolsExecuted = true
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil {
errContent := fmt.Sprintf("Error: %v", err)
toolErrMsg := store.NewMessage("tool", errContent, nil)
toolErrMsg.ToolName = toolCall.Function.Name
chat.Messages = append(chat.Messages, toolErrMsg)
if err := s.Store.AppendMessage(chat.ID, toolErrMsg); err != nil {
return err
}
// Emit tool error event
toolResult := true
json.NewEncoder(w).Encode(responses.ChatEvent{
EventName: "tool",
Content: &errContent,
ToolName: &toolCall.Function.Name,
})
flusher.Flush()
json.NewEncoder(w).Encode(responses.ChatEvent{
EventName: "tool_result",
Content: &errContent,
ToolName: &toolCall.Function.Name,
ToolResult: &toolResult,
ToolResultData: nil, // No result data for errors
})
flusher.Flush()
continue
}
var tr json.RawMessage
if strings.HasPrefix(toolCall.Function.Name, "browser.search") {
// For standalone web_search, ensure the tool message has readable content
// so the second-pass model can consume results, while keeping browser state flow intact.
// We still persist tool msg with content below.
// (No browser state update needed for standalone.)
} else if strings.HasPrefix(toolCall.Function.Name, "browser") {
stateBytes, err := json.Marshal(browser.State())
if err != nil {
return fmt.Errorf("failed to marshal browser state: %w", err)
}
if err := s.Store.UpdateChatBrowserState(chat.ID, json.RawMessage(stateBytes)); err != nil {
return fmt.Errorf("failed to persist browser state to chat: %w", err)
}
// tool result is not added to the tool message for the browser tool
} else {
var err error
tr, err = json.Marshal(result)
if err != nil {
return fmt.Errorf("failed to marshal tool result: %w", err)
}
}
// ensure tool message sent back to the model has content (if empty, use a sensible fallback)
modelContent := content
if toolCall.Function.Name == "web_fetch" && modelContent == "" {
if str, ok := result.(string); ok {
modelContent = str
}
}
if modelContent == "" && len(tr) > 0 {
s.log().Debug("tool message empty, sending json result")
modelContent = string(tr)
}
toolMsg := store.NewMessage("tool", modelContent, &store.MessageOptions{
ToolResult: &tr,
})
toolMsg.ToolName = toolCall.Function.Name
chat.Messages = append(chat.Messages, toolMsg)
s.Store.AppendMessage(chat.ID, toolMsg)
// Emit tool message event (matching agent pattern)
toolResult := true
json.NewEncoder(w).Encode(responses.ChatEvent{
EventName: "tool",
Content: &content,
ToolName: &toolCall.Function.Name,
})
flusher.Flush()
var toolState any = nil
if browser != nil {
toolState = browser.State()
}
// Stream tool result to frontend
json.NewEncoder(w).Encode(responses.ChatEvent{
EventName: "tool_result",
Content: &content,
ToolName: &toolCall.Function.Name,
ToolResult: &toolResult,
ToolResultData: result,
ToolState: toolState,
})
flusher.Flush()
}
case EventChat:
// Append the new message to the chat history
if len(chat.Messages) == 0 || chat.Messages[len(chat.Messages)-1].Role != "assistant" {
newMsg := store.NewMessage("assistant", "", &store.MessageOptions{Model: req.Model})
chat.Messages = append(chat.Messages, newMsg)
// Append new message to database
if err := s.Store.AppendMessage(chat.ID, newMsg); err != nil {
return err
}
// Attach any buffered tool_calls (request-only) now that assistant has started
if len(pendingAssistantToolCalls) > 0 {
lastMsg := &chat.Messages[len(chat.Messages)-1]
lastMsg.ToolCalls = pendingAssistantToolCalls
pendingAssistantToolCalls = nil
if err := s.Store.UpdateLastMessage(chat.ID, *lastMsg); err != nil {
return err
}
}
}
// Append token to last assistant message & persist
lastMsg := &chat.Messages[len(chat.Messages)-1]
lastMsg.Content += res.Message.Content
lastMsg.UpdatedAt = time.Now()
// Update thinking time fields
if thinkingTimeStart != nil {
lastMsg.ThinkingTimeStart = thinkingTimeStart
}
if thinkingTimeEnd != nil {
lastMsg.ThinkingTimeEnd = thinkingTimeEnd
}
// Use optimized update for streaming
if err := s.Store.UpdateLastMessage(chat.ID, *lastMsg); err != nil {
return err
}
case EventThinking:
// Persist thinking content
if len(chat.Messages) == 0 || chat.Messages[len(chat.Messages)-1].Role != "assistant" {
newMsg := store.NewMessage("assistant", "", &store.MessageOptions{
Model: req.Model,
Thinking: res.Message.Thinking,
})
chat.Messages = append(chat.Messages, newMsg)
// Append new message to database
if err := s.Store.AppendMessage(chat.ID, newMsg); err != nil {
return err
}
// Attach any buffered tool_calls now that assistant exists
if len(pendingAssistantToolCalls) > 0 {
lastMsg := &chat.Messages[len(chat.Messages)-1]
lastMsg.ToolCalls = pendingAssistantToolCalls
pendingAssistantToolCalls = nil
if err := s.Store.UpdateLastMessage(chat.ID, *lastMsg); err != nil {
return err
}
}
} else {
// Update thinking content of existing message
lastMsg := &chat.Messages[len(chat.Messages)-1]
lastMsg.Thinking += res.Message.Thinking
lastMsg.UpdatedAt = time.Now()
// Update thinking time fields
if thinkingTimeStart != nil {
lastMsg.ThinkingTimeStart = thinkingTimeStart
}
if thinkingTimeEnd != nil {
lastMsg.ThinkingTimeEnd = thinkingTimeEnd
}
// Use optimized update for streaming
if err := s.Store.UpdateLastMessage(chat.ID, *lastMsg); err != nil {
return err
}
}
}
return nil
})
if err != nil {
s.log().Error("chat stream error", "error", err)
errorEvent := s.getError(err)
json.NewEncoder(w).Encode(errorEvent)
flusher.Flush()
return nil
}
// If no tools were executed, exit the loop
if !toolsExecuted {
break
}
passNum++
}
// handle cases where thinking started but didn't finish
// this can happen if the client disconnects or the request is cancelled
// TODO (jmorganca): this should be merged with code above
if thinkingTimeStart != nil && thinkingTimeEnd == nil {
now := time.Now()
thinkingTimeEnd = &now
if len(chat.Messages) > 0 && chat.Messages[len(chat.Messages)-1].Role == "assistant" {
lastMsg := &chat.Messages[len(chat.Messages)-1]
lastMsg.ThinkingTimeEnd = thinkingTimeEnd
lastMsg.UpdatedAt = time.Now()
s.Store.UpdateLastMessage(chat.ID, *lastMsg)
}
}
json.NewEncoder(w).Encode(responses.ChatEvent{EventName: "done"})
flusher.Flush()
if len(chat.Messages) > 0 {
chat.Messages[len(chat.Messages)-1].Stream = false
}
return s.Store.SetChat(*chat)
}
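// getChat returns a single chat by ID; it backfills missing tool names on tool messages and
// strips per-page text from the persisted browser state before responding.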
func (s *Server) getChat(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id")
if cid == "" {
return fmt.Errorf("chat ID is required")
}
chat, err := s.Store.Chat(cid)
if err != nil {
// Return empty chat if not found
data := responses.ChatResponse{
Chat: store.Chat{},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(data)
return nil //nolint:nilerr
}
// fill missing tool_name on tool messages (from previous tool_calls) so labels don’t flip after reload.
if chat != nil && len(chat.Messages) > 0 {
for i := range chat.Messages {
if chat.Messages[i].Role == "tool" && chat.Messages[i].ToolName == "" && chat.Messages[i].ToolResult != nil {
for j := i - 1; j >= 0; j-- {
if chat.Messages[j].Role == "assistant" && len(chat.Messages[j].ToolCalls) > 0 {
last := chat.Messages[j].ToolCalls[len(chat.Messages[j].ToolCalls)-1]
if last.Function.Name != "" {
chat.Messages[i].ToolName = last.Function.Name
}
break
}
}
}
}
}
browserState, ok := s.browserState(chat)
if !ok {
browserState = reconstructBrowserState(chat.Messages, tools.DefaultViewTokens)
}
// clear the text and lines of all pages as it is not needed for rendering
if browserState != nil {
for _, page := range browserState.URLToPage {
page.Lines = nil
page.Text = ""
}
if cleanedState, err := json.Marshal(browserState); err == nil {
chat.BrowserState = json.RawMessage(cleanedState)
}
}
data := responses.ChatResponse{
Chat: *chat,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(data)
return nil
}
func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id")
if cid == "" {
return fmt.Errorf("chat ID is required")
}
var req struct {
Title string `json:"title"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return fmt.Errorf("invalid request body: %w", err)
}
// Get the chat without loading attachments (we only need to update the title)
chat, err := s.Store.ChatWithOptions(cid, false)
if err != nil {
return fmt.Errorf("chat not found: %w", err)
}
// Update the title
chat.Title = req.Title
if err := s.Store.SetChat(*chat); err != nil {
return fmt.Errorf("failed to update chat: %w", err)
}
// Return the updated chat info
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(chatInfoFromChat(*chat))
return nil
}
func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id")
if cid == "" {
return fmt.Errorf("chat ID is required")
}
// Check if the chat exists (no need to load attachments)
_, err := s.Store.ChatWithOptions(cid, false)
if err != nil {
if errors.Is(err, not.Found) {
w.WriteHeader(http.StatusNotFound)
return fmt.Errorf("chat not found")
}
return fmt.Errorf("failed to get chat: %w", err)
}
// Delete the chat
if err := s.Store.DeleteChat(cid); err != nil {
return fmt.Errorf("failed to delete chat: %w", err)
}
w.WriteHeader(http.StatusOK)
return nil
}
// TODO(parthsareen): consolidate events within the function
func chatEventFromApiChatResponse(res api.ChatResponse, thinkingTimeStart *time.Time, thinkingTimeEnd *time.Time) responses.ChatEvent {
// If there are tool calls, send assistant_with_tools event
if len(res.Message.ToolCalls) > 0 {
// Convert API tool calls to store tool calls
storeToolCalls := make([]store.ToolCall, len(res.Message.ToolCalls))
for i, tc := range res.Message.ToolCalls {
argsJSON, _ := json.Marshal(tc.Function.Arguments)
storeToolCalls[i] = store.ToolCall{
Type: "function",
Function: store.ToolFunction{
Name: tc.Function.Name,
Arguments: string(argsJSON),
},
}
}
var content *string
if res.Message.Content != "" {
content = &res.Message.Content
}
var thinking *string
if res.Message.Thinking != "" {
thinking = &res.Message.Thinking
}
return responses.ChatEvent{
EventName: "assistant_with_tools",
Content: content,
Thinking: thinking,
ToolCalls: storeToolCalls,
ThinkingTimeStart: thinkingTimeStart,
ThinkingTimeEnd: thinkingTimeEnd,
}
}
// Otherwise, send regular chat event
var content *string
if res.Message.Content != "" {
content = &res.Message.Content
}
var thinking *string
if res.Message.Thinking != "" {
thinking = &res.Message.Thinking
}
return responses.ChatEvent{
EventName: "chat",
Content: content,
Thinking: thinking,
ThinkingTimeStart: thinkingTimeStart,
ThinkingTimeEnd: thinkingTimeEnd,
}
}
func chatInfoFromChat(chat store.Chat) responses.ChatInfo {
userExcerpt := ""
var updatedAt time.Time
for _, msg := range chat.Messages {
// extract the first user message as the user excerpt
if msg.Role == "user" && userExcerpt == "" {
userExcerpt = msg.Content
}
// update the updated at time
if msg.UpdatedAt.After(updatedAt) {
updatedAt = msg.UpdatedAt
}
}
return responses.ChatInfo{
ID: chat.ID,
Title: chat.Title,
UserExcerpt: userExcerpt,
CreatedAt: chat.CreatedAt,
UpdatedAt: updatedAt,
}
}
func (s *Server) getSettings(w http.ResponseWriter, r *http.Request) error {
settings, err := s.Store.Settings()
if err != nil {
return fmt.Errorf("failed to load settings: %w", err)
}
// set default models directory if not set
if settings.Models == "" {
settings.Models = envconfig.Models()
}
// set default context length if not set
if settings.ContextLength == 0 {
settings.ContextLength = 4096
}
// Include current runtime settings
settings.Agent = s.Agent
settings.Tools = s.Tools
settings.WorkingDir = s.WorkingDir
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(responses.SettingsResponse{
Settings: settings,
})
}
func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
old, err := s.Store.Settings()
if err != nil {
return fmt.Errorf("failed to load settings: %w", err)
}
var settings store.Settings
if err := json.NewDecoder(r.Body).Decode(&settings); err != nil {
return fmt.Errorf("invalid request body: %w", err)
}
if err := s.Store.SetSettings(settings); err != nil {
return fmt.Errorf("failed to save settings: %w", err)
}
if old.ContextLength != settings.ContextLength ||
old.Models != settings.Models ||
old.Expose != settings.Expose {
s.Restart()
}
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(responses.SettingsResponse{
Settings: settings,
})
}
func (s *Server) me(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
user, err := s.UserData(r.Context())
if err != nil {
// If fetching from API fails, try to return cached user data if available
if cachedUser, cacheErr := s.Store.User(); cacheErr == nil && cachedUser != nil {
s.log().Info("API request failed, returning cached user data", "error", err)
responseUser := &responses.User{
Name: cachedUser.Name,
Email: cachedUser.Email,
Plan: cachedUser.Plan,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(responseUser)
}
s.log().Error("failed to get user data", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to get user data",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(user)
}
func (s *Server) disconnect(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
if err := s.Store.ClearUser(); err != nil {
s.log().Warn("failed to clear cached user data", "error", err)
}
// Get the SSH public key to encode for the delete request
pubKey, err := ollamaAuth.GetPublicKey()
if err != nil {
s.log().Error("failed to get public key", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to get public key",
})
}
// Encode the key using base64 URL encoding
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
// Call the /api/user/keys/{encodedKey} endpoint with DELETE
resp, err := s.doSelfSigned(r.Context(), http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey))
if err != nil {
s.log().Error("failed to call ollama.com/api/user/keys", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to disconnect from ollama.com",
})
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
s.log().Error("disconnect request failed", "status", resp.StatusCode)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to disconnect from ollama.com",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"})
}
func (s *Server) connectURL(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
connectURL, err := auth.BuildConnectURL(OllamaDotCom)
if err != nil {
s.log().Error("failed to build connect URL", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to build connect URL",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(map[string]string{
"connect_url": connectURL,
})
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
healthy := false
c, err := api.ClientFromEnvironment()
if err == nil {
if _, err := c.Version(r.Context()); err == nil {
healthy = true
}
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(responses.HealthResponse{
Healthy: healthy,
})
}
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel()
serverInferenceComputes, err := server.GetInferenceComputer(ctx)
if err != nil {
s.log().Error("failed to get inference compute", "error", err)
return fmt.Errorf("failed to get inference compute: %w", err)
}
inferenceComputes := make([]responses.InferenceCompute, len(serverInferenceComputes))
for i, ic := range serverInferenceComputes {
inferenceComputes[i] = responses.InferenceCompute{
Library: ic.Library,
Variant: ic.Variant,
Compute: ic.Compute,
Driver: ic.Driver,
Name: ic.Name,
VRAM: ic.VRAM,
}
}
response := responses.InferenceComputeResponse{
InferenceComputes: inferenceComputes,
}
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(response)
}
func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
if r.Method != "POST" {
return fmt.Errorf("method not allowed")
}
var req struct {
Model string `json:"model"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return fmt.Errorf("invalid request body: %w", err)
}
if req.Model == "" {
return fmt.Errorf("model is required")
}
digest, pushTime, err := s.checkModelUpstream(r.Context(), req.Model, 5*time.Second)
if err != nil {
s.log().Warn("failed to check upstream digest", "error", err, "model", req.Model)
response := responses.ModelUpstreamResponse{
Error: err.Error(),
}
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(response)
}
response := responses.ModelUpstreamResponse{
Digest: digest,
PushTime: pushTime,
}
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(response)
}
func userAgent() string {
buildinfo, _ := debug.ReadBuildInfo()
version := buildinfo.Main.Version
if version == "(devel)" {
// When using `go run .` the version is "(devel)". This is seen
// as an invalid version by ollama.com and so it defaults to
// "needs upgrade" for some requests, such as pulls. These
// checks can be skipped by using the special version "v0.0.0",
// so we set it to that here.
version = "v0.0.0"
}
return fmt.Sprintf("ollama/%s (%s %s) app/%s Go/%s",
version,
runtime.GOARCH,
runtime.GOOS,
version,
runtime.Version(),
)
}
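// For illustration, the resulting header looks like
// "ollama/v0.0.0 (arm64 darwin) app/v0.0.0 Go/go1.22.0" (exact values depend on the build).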
// convertToOllamaTool converts a tool schema from our tools package format to Ollama API format
func convertToOllamaTool(toolSchema map[string]any) api.Tool {
tool := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: getStringFromMap(toolSchema, "name", ""),
Description: getStringFromMap(toolSchema, "description", ""),
},
}
tool.Function.Parameters.Type = "object"
tool.Function.Parameters.Required = []string{}
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
if props, ok := schemaProps["properties"].(map[string]any); ok {
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
for propName, propDef := range props {
if propMap, ok := propDef.(map[string]any); ok {
prop := api.ToolProperty{
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
Description: getStringFromMap(propMap, "description", ""),
}
tool.Function.Parameters.Properties[propName] = prop
}
}
}
if required, ok := schemaProps["required"].([]string); ok {
tool.Function.Parameters.Required = required
} else if requiredAny, ok := schemaProps["required"].([]any); ok {
required := make([]string, len(requiredAny))
for i, r := range requiredAny {
if s, ok := r.(string); ok {
required[i] = s
}
}
tool.Function.Parameters.Required = required
}
}
return tool
}
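// For illustration (field names hypothetical): a schema such as
//   {"name": "web_fetch", "description": "Fetch a URL",
//    "schema": {"type": "object", "properties": {"url": {"type": "string", "description": "page to fetch"}}, "required": ["url"]}}
// becomes an api.Tool whose Function has Name "web_fetch" and a single required string property "url".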
// getStringFromMap safely gets a string from a map
func getStringFromMap(m map[string]any, key, defaultValue string) string {
if val, ok := m[key].(string); ok {
return val
}
return defaultValue
}
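// e.g. getStringFromMap(map[string]any{"type": 42}, "type", "object") returns "object"
// because the stored value is not a string.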
// isImageAttachment checks if a filename is an image file
func isImageAttachment(filename string) bool {
ext := strings.ToLower(filename)
return strings.HasSuffix(ext, ".png") || strings.HasSuffix(ext, ".jpg") || strings.HasSuffix(ext, ".jpeg")
}
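// Note: the check is suffix-based on the lowercased name, so "Photo.PNG" counts as an image,
// while other attachments (e.g. ".pdf") are inlined as text by buildChatRequest.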
// ptr is a convenience function for &literal
func ptr[T any](v T) *T { return &v }
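// e.g. Stream: ptr(true) below avoids declaring a named *bool temporary.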
// Browser tools simulate a full browser environment, allowing for actions like searching, opening, and interacting with web pages (e.g., "browser_search", "browser_open", "browser_find"). Currently only gpt-oss models support browser tools.
func supportsBrowserTools(model string) bool {
return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
}
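// e.g. supportsBrowserTools("GPT-OSS:20b") is true (the prefix check is case-insensitive);
// supportsBrowserTools("qwen3") is false.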
// Web search tools are simpler, providing only basic web search and fetch capabilities (e.g., "web_search", "web_fetch") without simulating a browser. Currently only qwen3 and deepseek-v3 support web search tools.
func supportsWebSearchTools(model string) bool {
model = strings.ToLower(model)
prefixes := []string{"qwen3", "deepseek-v3"}
for _, p := range prefixes {
if strings.HasPrefix(model, p) {
return true
}
}
return false
}
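// e.g. supportsWebSearchTools("qwen3:8b") and supportsWebSearchTools("deepseek-v3.1") are true;
// any model outside the prefix list gets no web search tools.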
// buildChatRequest converts store.Chat to api.ChatRequest
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
var msgs []api.Message
for _, m := range chat.Messages {
// Skip empty messages if present
if m.Content == "" && m.Thinking == "" && len(m.ToolCalls) == 0 && len(m.Attachments) == 0 {
continue
}
apiMsg := api.Message{Role: m.Role, Thinking: m.Thinking}
sb := strings.Builder{}
sb.WriteString(m.Content)
var images []api.ImageData
if m.Role == "user" && len(m.Attachments) > 0 {
for _, a := range m.Attachments {
if isImageAttachment(a.Filename) {
images = append(images, api.ImageData(a.Data))
} else {
content := convertBytesToText(a.Data, a.Filename)
sb.WriteString(fmt.Sprintf("\n--- File: %s ---\n%s\n--- End of %s ---",
a.Filename, content, a.Filename))
}
}
}
apiMsg.Content = sb.String()
apiMsg.Images = images
switch m.Role {
case "assistant":
if len(m.ToolCalls) > 0 {
var toolCalls []api.ToolCall
for _, tc := range m.ToolCalls {
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
s.log().Error("failed to parse tool call arguments", "error", err, "function_name", tc.Function.Name, "arguments", tc.Function.Arguments)
continue
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: tc.Function.Name,
Arguments: args,
},
})
}
apiMsg.ToolCalls = toolCalls
}
case "tool":
apiMsg.Role = "tool"
apiMsg.Content = m.Content
apiMsg.ToolName = m.ToolName
case "user", "system":
// User and system messages are handled normally
default:
// Log unknown roles but still include them
s.log().Debug("unknown message role", "role", m.Role)
}
msgs = append(msgs, apiMsg)
}
var thinkValue *api.ThinkValue
if think != nil {
if boolValue, ok := think.(bool); ok {
thinkValue = &api.ThinkValue{
Value: boolValue,
}
} else if stringValue, ok := think.(string); ok {
thinkValue = &api.ThinkValue{
Value: stringValue,
}
}
}
req := &api.ChatRequest{
Model: model,
Messages: msgs,
Stream: ptr(true),
Think: thinkValue,
}
if len(availableTools) > 0 {
tools := make(api.Tools, len(availableTools))
for i, toolSchema := range availableTools {
tools[i] = convertToOllamaTool(toolSchema)
}
req.Tools = tools
}
return req, nil
}
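// For illustration: think=true yields api.ThinkValue{Value: true}, a level string such as "high"
// yields api.ThinkValue{Value: "high"}, and any other type leaves Think unset (nil).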
//go:build windows || darwin
package ui
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"runtime"
"strings"
"testing"
"github.com/ollama/ollama/app/store"
)
func TestHandlePostApiSettings(t *testing.T) {
tests := []struct {
name string
requested store.Settings
wantErr bool
}{
{
name: "valid settings update - all fields",
requested: store.Settings{
Expose: true,
Browser: true,
Models: "/custom/models",
Agent: true,
Tools: true,
WorkingDir: "/workspace",
},
wantErr: false,
},
{
name: "partial settings update",
requested: store.Settings{
Agent: true,
Tools: false,
WorkingDir: "/new/path",
},
wantErr: false,
},
{
name: "settings with special characters in paths",
requested: store.Settings{
Models: "/path with spaces/models",
WorkingDir: "/tmp/work-dir_123",
Agent: true,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testStore := &store.Store{
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
}
defer testStore.Close() // Ensure database is closed before cleanup
body, err := json.Marshal(tt.requested)
if err != nil {
t.Fatalf("failed to marshal test body: %v", err)
}
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
// Set up server with test store
server := &Server{
Store: testStore,
Restart: func() {}, // Mock restart function for tests
}
if err := server.settings(rr, req); (err != nil) != tt.wantErr {
t.Errorf("handlePostApiSettings() error = %v, wantErr %v", err, tt.wantErr)
}
if rr.Code != http.StatusOK {
t.Errorf("handlePostApiSettings() status = %v, want %v", rr.Code, http.StatusOK)
}
// Check settings were saved correctly (if no error expected)
if !tt.wantErr {
savedSettings, err := testStore.Settings()
if err != nil {
t.Errorf("failed to retrieve saved settings: %v", err)
} else {
// Compare field by field, accounting for defaults that may be set by the store
if savedSettings.Expose != tt.requested.Expose {
t.Errorf("Expose: got %v, want %v", savedSettings.Expose, tt.requested.Expose)
}
if savedSettings.Browser != tt.requested.Browser {
t.Errorf("Browser: got %v, want %v", savedSettings.Browser, tt.requested.Browser)
}
if savedSettings.Agent != tt.requested.Agent {
t.Errorf("Agent: got %v, want %v", savedSettings.Agent, tt.requested.Agent)
}
if savedSettings.Tools != tt.requested.Tools {
t.Errorf("Tools: got %v, want %v", savedSettings.Tools, tt.requested.Tools)
}
if savedSettings.WorkingDir != tt.requested.WorkingDir {
t.Errorf("WorkingDir: got %q, want %q", savedSettings.WorkingDir, tt.requested.WorkingDir)
}
// Only check Models if explicitly set in the test case
if tt.requested.Models != "" && savedSettings.Models != tt.requested.Models {
t.Errorf("Models: got %q, want %q", savedSettings.Models, tt.requested.Models)
}
}
}
})
}
}
func TestAuthenticationMiddleware(t *testing.T) {
tests := []struct {
name string
method string
contentType string
tokenCookie string
serverToken string
wantStatus int
wantError string
setupRequest func(*http.Request)
}{
{
name: "missing token cookie",
method: "GET",
tokenCookie: "",
serverToken: "test-token-123",
wantStatus: http.StatusForbidden,
wantError: "Token is required",
},
{
name: "invalid token value",
method: "GET",
tokenCookie: "wrong-token",
serverToken: "test-token-123",
wantStatus: http.StatusForbidden,
wantError: "Token is required",
},
{
name: "valid token - GET request",
method: "GET",
tokenCookie: "test-token-123",
serverToken: "test-token-123",
wantStatus: http.StatusOK,
wantError: "",
},
{
name: "valid token - POST with application/json",
method: "POST",
contentType: "application/json",
tokenCookie: "test-token-123",
serverToken: "test-token-123",
wantStatus: http.StatusOK,
wantError: "",
},
{
name: "POST without Content-Type header",
method: "POST",
contentType: "",
tokenCookie: "test-token-123",
serverToken: "test-token-123",
wantStatus: http.StatusForbidden,
wantError: "Content-Type must be application/json",
},
{
name: "POST with wrong Content-Type",
method: "POST",
contentType: "text/plain",
tokenCookie: "test-token-123",
serverToken: "test-token-123",
wantStatus: http.StatusForbidden,
wantError: "Content-Type must be application/json",
},
{
name: "OPTIONS request (CORS preflight) - should bypass auth",
method: "OPTIONS",
tokenCookie: "",
serverToken: "test-token-123",
wantStatus: http.StatusOK,
wantError: "",
setupRequest: func(r *http.Request) {
// Simulate CORS being enabled
// Note: This assumes CORS() returns true in test environment
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a test handler that just returns 200 OK if auth passes
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
})
// Create server with test token
server := &Server{
Token: tt.serverToken,
}
// Get the authentication middleware by calling Handler()
// We need to wrap our test handler with the auth middleware
handler := server.Handler()
// Create a test router to simulate the authentication middleware
mux := http.NewServeMux()
mux.Handle("/test", handler)
// But since Handler() returns the full router, we'll need a different approach
// Let's create a minimal handler that includes just the auth logic
authHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add CORS headers for dev work
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
if r.Method == "POST" && r.Header.Get("Content-Type") != "application/json" {
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{"error": "Content-Type must be application/json"})
return
}
cookie, err := r.Cookie("token")
if err != nil {
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{"error": "Token is required"})
return
}
if cookie.Value != server.Token {
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{"error": "Token is required"})
return
}
// If auth passes, call the test handler
testHandler.ServeHTTP(w, r)
})
// Create test request
req := httptest.NewRequest(tt.method, "/test", nil)
// Set Content-Type if provided
if tt.contentType != "" {
req.Header.Set("Content-Type", tt.contentType)
}
// Set token cookie if provided
if tt.tokenCookie != "" {
req.AddCookie(&http.Cookie{
Name: "token",
Value: tt.tokenCookie,
})
}
// Run any additional setup
if tt.setupRequest != nil {
tt.setupRequest(req)
}
// Create response recorder
rr := httptest.NewRecorder()
// Serve the request
authHandler.ServeHTTP(rr, req)
// Check status code
if rr.Code != tt.wantStatus {
t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, tt.wantStatus)
}
// Check error message if expected
if tt.wantError != "" {
var response map[string]string
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response body: %v", err)
}
if response["error"] != tt.wantError {
t.Errorf("handler returned wrong error message: got %v want %v", response["error"], tt.wantError)
}
}
})
}
}
func TestUserAgent(t *testing.T) {
ua := userAgent()
// The userAgent function should return a string in the format:
// "ollama/version (arch os) app/version Go/goversion"
// Example: "ollama/v0.1.28 (amd64 darwin) app/v0.1.28 Go/go1.21.0"
if ua == "" {
t.Fatal("userAgent returned empty string")
}
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("User-Agent", ua)
// This is a copy of the logic ollama.com uses to parse the user agent
clientInfoFromRequest := func(r *http.Request) struct {
Product string
Version string
OS string
Arch string
AppVersion string
} {
product, rest, _ := strings.Cut(r.UserAgent(), " ")
client, version, ok := strings.Cut(product, "/")
if !ok {
return struct {
Product string
Version string
OS string
Arch string
AppVersion string
}{}
}
if version != "" && version[0] != 'v' {
version = "v" + version
}
arch, rest, _ := strings.Cut(rest, " ")
arch = strings.Trim(arch, "(")
os, rest, _ := strings.Cut(rest, ")")
var appVersion string
if strings.Contains(rest, "app/") {
_, appPart, found := strings.Cut(rest, "app/")
if found {
appVersion = strings.Fields(strings.TrimSpace(appPart))[0]
if appVersion != "" && appVersion[0] != 'v' {
appVersion = "v" + appVersion
}
}
}
return struct {
Product string
Version string
OS string
Arch string
AppVersion string
}{
Product: client,
Version: version,
OS: os,
Arch: arch,
AppVersion: appVersion,
}
}
info := clientInfoFromRequest(req)
if info.Product != "ollama" {
t.Errorf("Expected Product to be 'ollama', got '%s'", info.Product)
}
if info.Version != "" && info.Version[0] != 'v' {
t.Errorf("Expected Version to start with 'v', got '%s'", info.Version)
}
expectedOS := runtime.GOOS
if info.OS != expectedOS {
t.Errorf("Expected OS to be '%s', got '%s'", expectedOS, info.OS)
}
expectedArch := runtime.GOARCH
if info.Arch != expectedArch {
t.Errorf("Expected Arch to be '%s', got '%s'", expectedArch, info.Arch)
}
if info.AppVersion != "" && info.AppVersion[0] != 'v' {
t.Errorf("Expected AppVersion to start with 'v', got '%s'", info.AppVersion)
}
t.Logf("User Agent: %s", ua)
t.Logf("Parsed - Product: %s, Version: %s, OS: %s, Arch: %s",
info.Product, info.Version, info.OS, info.Arch)
}
func TestUserAgentTransport(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.Write([]byte(r.Header.Get("User-Agent")))
}))
defer ts.Close()
server := &Server{}
client := server.httpClient()
resp, err := client.Get(ts.URL)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer resp.Body.Close()
// In this case the User-Agent is the response body, as the server just echoes it back
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
}
receivedUA := string(body)
expectedUA := userAgent()
if receivedUA != expectedUA {
t.Errorf("User-Agent mismatch\nExpected: %s\nReceived: %s", expectedUA, receivedUA)
}
if !strings.HasPrefix(receivedUA, "ollama/") {
t.Errorf("User-Agent should start with 'ollama/', got: %s", receivedUA)
}
t.Logf("User-Agent transport successfully set: %s", receivedUA)
}
-package lifecycle
+//go:build windows || darwin
+package updater
 import (
 "context"
@@ -19,14 +21,24 @@ import (
 "strings"
 "time"
+"github.com/ollama/ollama/app/store"
+"github.com/ollama/ollama/app/version"
 "github.com/ollama/ollama/auth"
-"github.com/ollama/ollama/version"
 )
 var (
 UpdateCheckURLBase = "https://ollama.com/api/update"
 UpdateDownloaded = false
 UpdateCheckInterval = 60 * 60 * time.Second
+UpdateCheckInitialDelay = 3 * time.Second // 30 * time.Second
+UpdateStageDir string
+UpgradeLogFile string
+UpgradeMarkerFile string
+Installer string
+UserAgentOS string
+VerifyDownload func() error
 )
 // TODO - maybe move up to the API package?
@@ -35,7 +47,7 @@ type UpdateResponse struct {
 UpdateVersion string `json:"version"`
 }
-func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
+func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
 var updateResp UpdateResponse
 requestURL, err := url.Parse(UpdateCheckURLBase)
@@ -49,18 +61,29 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
 query.Add("version", version.Version)
 query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
-nonce, err := auth.NewNonce(rand.Reader, 16)
-if err != nil {
-return false, updateResp
+// The original macOS app used to use the device ID
+// to check for updates so include it if present
+if runtime.GOOS == "darwin" {
+if id, err := u.Store.ID(); err == nil && id != "" {
+query.Add("id", id)
+}
 }
-query.Add("nonce", nonce)
-requestURL.RawQuery = query.Encode()
-data := []byte(fmt.Sprintf("%s,%s", http.MethodGet, requestURL.RequestURI()))
-signature, err := auth.Sign(ctx, data)
+var signature string
+nonce, err := auth.NewNonce(rand.Reader, 16)
 if err != nil {
-return false, updateResp
+// Don't sign if we haven't yet generated a key pair for the server
+slog.Debug("unable to generate nonce for update check request", "error", err)
+} else {
+query.Add("nonce", nonce)
+requestURL.RawQuery = query.Encode()
+data := []byte(fmt.Sprintf("%s,%s", http.MethodGet, requestURL.RequestURI()))
+signature, err = auth.Sign(ctx, data)
+if err != nil {
+slog.Debug("unable to generate signature for update check request", "error", err)
+}
 }
 req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
@@ -68,10 +91,13 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
 slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
 return false, updateResp
 }
-req.Header.Set("Authorization", signature)
-req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
+if signature != "" {
+req.Header.Set("Authorization", signature)
+}
+ua := fmt.Sprintf("ollama/%s %s Go/%s %s", version.Version, runtime.GOARCH, runtime.Version(), UserAgentOS)
+req.Header.Set("User-Agent", ua)
-slog.Debug("checking for available update", "requestURL", requestURL)
+slog.Debug("checking for available update", "requestURL", requestURL, "User-Agent", ua)
 resp, err := http.DefaultClient.Do(req)
 if err != nil {
 slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
@@ -104,13 +130,27 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
 return true, updateResp
 }
-func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
+func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
 // Do a head first to check etag info
 req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
 if err != nil {
 return err
 }
+// In case of slow downloads, continue the update check in the background
+bgctx, cancel := context.WithCancel(ctx)
+defer cancel()
+go func() {
+for {
+select {
+case <-bgctx.Done():
+return
+case <-time.After(UpdateCheckInterval):
+u.checkForUpdate(bgctx)
+}
+}
+}()
 resp, err := http.DefaultClient.Do(req)
 if err != nil {
 return fmt.Errorf("error checking update: %w", err)
@@ -135,11 +175,11 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
 // Check to see if we already have it downloaded
 _, err = os.Stat(stageFilename)
 if err == nil {
-slog.Info("update already downloaded")
+slog.Info("update already downloaded", "bundle", stageFilename)
 return nil
 }
-cleanupOldDownloads()
+cleanupOldDownloads(UpdateStageDir)
 req.Method = http.MethodGet
 resp, err = http.DefaultClient.Do(req)
@@ -176,12 +216,16 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
 }
 slog.Info("new update downloaded " + stageFilename)
+if err := VerifyDownload(); err != nil {
+_ = os.Remove(stageFilename)
+return fmt.Errorf("%s - %s", resp.Request.URL.String(), err)
+}
 UpdateDownloaded = true
 return nil
 }
-func cleanupOldDownloads() {
-files, err := os.ReadDir(UpdateStageDir)
+func cleanupOldDownloads(stageDir string) {
+files, err := os.ReadDir(stageDir)
 if err != nil && errors.Is(err, os.ErrNotExist) {
 // Expected behavior on first run
 return
@@ -190,7 +234,7 @@ func cleanupOldDownloads() {
 return
 }
 for _, file := range files {
-fullname := filepath.Join(UpdateStageDir, file.Name())
+fullname := filepath.Join(stageDir, file.Name())
 slog.Debug("cleaning up old download: " + fullname)
 err = os.RemoveAll(fullname)
 if err != nil {
@@ -199,22 +243,26 @@ func cleanupOldDownloads() {
 }
 }
-func StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
+type Updater struct {
+Store *store.Store
+}
+func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
 go func() {
 // Don't blast an update message immediately after startup
-// time.Sleep(30 * time.Second)
-time.Sleep(3 * time.Second)
+time.Sleep(UpdateCheckInitialDelay)
+slog.Info("beginning update checker", "interval", UpdateCheckInterval)
 for {
-available, resp := IsNewReleaseAvailable(ctx)
+available, resp := u.checkForUpdate(ctx)
 if available {
-err := DownloadNewRelease(ctx, resp)
+err := u.DownloadNewRelease(ctx, resp)
 if err != nil {
 slog.Error(fmt.Sprintf("failed to download new release: %s", err))
-}
-err = cb(resp.UpdateVersion)
-if err != nil {
-slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
+} else {
+err = cb(resp.UpdateVersion)
+if err != nil {
+slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
+}
 }
 }
 select {
...
package updater
// #cgo CFLAGS: -x objective-c
// #cgo LDFLAGS: -framework Webkit -framework Cocoa -framework LocalAuthentication -framework ServiceManagement
// #include "updater_darwin.h"
// typedef const char cchar_t;
import "C"
import (
"archive/zip"
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/user"
"path/filepath"
"strings"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
var (
appBackupDir string
SystemWidePath = "/Applications/Ollama.app"
)
var BundlePath = func() string {
if bundle := alreadyMoved(); bundle != "" {
return bundle
}
exe, err := os.Executable()
if err != nil {
return ""
}
// We also install this binary in Contents/Frameworks/Squirrel.framework/Versions/A/Squirrel
if filepath.Base(exe) == "Squirrel" &&
filepath.Base(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(exe)))))) == "Contents" {
return filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(exe))))))
}
// Make sure we're in a proper macOS app bundle structure (Contents/MacOS)
if filepath.Base(filepath.Dir(exe)) != "MacOS" ||
filepath.Base(filepath.Dir(filepath.Dir(exe))) != "Contents" {
return ""
}
return filepath.Dir(filepath.Dir(filepath.Dir(exe)))
}()
func init() {
VerifyDownload = verifyDownload
Installer = "Ollama-darwin.zip"
home, err := os.UserHomeDir()
if err != nil {
panic(err)
}
var uts unix.Utsname
if err := unix.Uname(&uts); err == nil {
sysname := unix.ByteSliceToString(uts.Sysname[:])
release := unix.ByteSliceToString(uts.Release[:])
UserAgentOS = fmt.Sprintf("%s/%s", sysname, release)
} else {
slog.Warn("unable to determine OS version", "error", err)
UserAgentOS = "Darwin"
}
// TODO handle failure modes here, and developer mode better...
// Executable = Ollama.app/Contents/MacOS/Ollama
UpgradeLogFile = filepath.Join(home, ".ollama", "logs", "upgrade.log")
cacheDir, err := os.UserCacheDir()
if err != nil {
slog.Warn("unable to determine user cache dir, falling back to tmpdir", "error", err)
cacheDir = os.TempDir()
}
appDataDir := filepath.Join(cacheDir, "ollama")
UpgradeMarkerFile = filepath.Join(appDataDir, "upgraded")
appBackupDir = filepath.Join(appDataDir, "backup")
UpdateStageDir = filepath.Join(appDataDir, "updates")
}
func DoUpgrade(interactive bool) error {
// TODO use UpgradeLogFile to record the upgrade details from->to version, etc.
bundle := getStagedUpdate()
if bundle == "" {
return fmt.Errorf("failed to lookup downloads")
}
slog.Info("starting upgrade", "app", BundlePath, "update", bundle, "pid", os.Getpid(), "log", UpgradeLogFile)
// TODO - in the future, consider shutting down the backend server now to give it
// time to drain connections and stop allowing new connections while we perform the
// actual upgrade to reduce the overall time to complete
contentsName := filepath.Join(BundlePath, "Contents")
appBackup := filepath.Join(appBackupDir, "Ollama.app")
contentsOldName := filepath.Join(appBackup, "Contents")
// Verify old doesn't exist yet
if _, err := os.Stat(contentsOldName); err == nil {
slog.Error("prior upgrade failed", "backup", contentsOldName)
return fmt.Errorf("prior upgrade failed - please upgrade manually by installing the bundle")
}
if err := os.MkdirAll(appBackupDir, 0o755); err != nil {
return fmt.Errorf("unable to create backup dir %s: %w", appBackupDir, err)
}
// Verify bundle loads before starting staging process
r, err := zip.OpenReader(bundle)
if err != nil {
return fmt.Errorf("unable to open upgrade bundle %s: %w", bundle, err)
}
defer r.Close()
slog.Debug("temporarily staging old version", "staging", appBackup)
if err := os.Rename(BundlePath, appBackup); err != nil {
if !interactive {
// We don't want to prompt for permission if we're attempting to upgrade at startup
return fmt.Errorf("unable to upgrade in non-interactive mode with permission problems: %w", err)
}
// TODO actually inspect the error and look for permission problems before trying chown
slog.Warn("unable to backup old version due to permission problems, changing ownership", "error", err)
u, err := user.Current()
if err != nil {
return err
}
if !chownWithAuthorization(u.Username) {
return fmt.Errorf("unable to change permissions to complete upgrade")
}
if err := os.Rename(BundlePath, appBackup); err != nil {
return fmt.Errorf("unable to perform upgrade - failed to stage old version: %w", err)
}
}
// Get ready to try to unwind a partial upgrade failure during unzip
// If something goes wrong, we attempt to put the old version back.
anyFailures := false
defer func() {
if anyFailures {
slog.Warn("upgrade failures detected, attempting to revert")
if err := os.RemoveAll(BundlePath); err != nil {
slog.Warn("failed to remove partial upgrade", "path", BundlePath, "error", err)
// At this point, we're basically hosed and the user will need to re-install
return
}
if err := os.Rename(appBackup, BundlePath); err != nil {
slog.Error("failed to revert to prior version", "path", contentsName, "error", err)
}
}
}()
// Bundle contents Ollama.app/Contents/...
links := []*zip.File{}
for _, f := range r.File {
s := strings.SplitN(f.Name, "/", 2)
if len(s) < 2 || s[1] == "" {
slog.Debug("skipping", "file", f.Name)
continue
}
name := s[1]
if strings.HasSuffix(name, "/") {
d := filepath.Join(BundlePath, name)
err := os.MkdirAll(d, 0o755)
if err != nil {
anyFailures = true
return fmt.Errorf("failed to mkdir %s: %w", d, err)
}
continue
}
if f.Mode()&os.ModeSymlink != 0 {
// Defer links to the end
links = append(links, f)
continue
}
src, err := f.Open()
if err != nil {
anyFailures = true
return fmt.Errorf("failed to open bundle file %s: %w", name, err)
}
destName := filepath.Join(BundlePath, name)
// Verify directory first
d := filepath.Dir(destName)
if _, err := os.Stat(d); err != nil {
err := os.MkdirAll(d, 0o755)
if err != nil {
anyFailures = true
return fmt.Errorf("failed to mkdir %s: %w", d, err)
}
}
destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
anyFailures = true
return fmt.Errorf("failed to open output file %s: %w", destName, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, src); err != nil {
anyFailures = true
return fmt.Errorf("failed to open extract file %s: %w", destName, err)
}
}
for _, f := range links {
s := strings.SplitN(f.Name, "/", 2) // Strip off Ollama.app/
if len(s) < 2 || s[1] == "" {
slog.Debug("skipping link", "file", f.Name)
continue
}
name := s[1]
src, err := f.Open()
if err != nil {
anyFailures = true
return err
}
buf, err := io.ReadAll(src)
if err != nil {
anyFailures = true
return err
}
link := string(buf)
if link[0] == '/' {
anyFailures = true
return fmt.Errorf("bundle contains absolute symlink %s -> %s", f.Name, link)
}
// Don't allow links outside of Ollama.app
if strings.HasPrefix(filepath.Join(filepath.Dir(name), link), "..") {
anyFailures = true
return fmt.Errorf("bundle contains link outside of contents %s -> %s", f.Name, link)
}
if err = os.Symlink(link, filepath.Join(BundlePath, name)); err != nil {
anyFailures = true
return err
}
}
f, err := os.OpenFile(UpgradeMarkerFile, os.O_RDONLY|os.O_CREATE, 0o666)
if err != nil {
slog.Warn("unable to create marker file", "file", UpgradeMarkerFile, "error", err)
} else {
f.Close()
}
// Make sure to remove the staged download now that we succeeded so we don't inadvertently try again.
cleanupOldDownloads(UpdateStageDir)
return nil
}
func DoPostUpgradeCleanup() error {
slog.Debug("post upgrade cleanup", "backup", appBackupDir)
err := os.RemoveAll(appBackupDir)
if err != nil {
return err
}
slog.Debug("post upgrade cleanup", "old", UpgradeMarkerFile)
return os.Remove(UpgradeMarkerFile)
}
func verifyDownload() error {
bundle := getStagedUpdate()
if bundle == "" {
return fmt.Errorf("failed to lookup downloads")
}
slog.Debug("verifying update", "bundle", bundle)
// Extract zip file into a temporary location so we can run the cert verification routines
dir, err := os.MkdirTemp("", "ollama_update_verify")
if err != nil {
return err
}
defer os.RemoveAll(dir)
r, err := zip.OpenReader(bundle)
if err != nil {
return fmt.Errorf("unable to open upgrade bundle %s: %w", bundle, err)
}
defer r.Close()
links := []*zip.File{}
for _, f := range r.File {
if strings.HasSuffix(f.Name, "/") {
d := filepath.Join(dir, f.Name)
err := os.MkdirAll(d, 0o755)
if err != nil {
return fmt.Errorf("failed to mkdir %s: %w", d, err)
}
continue
}
if f.Mode()&os.ModeSymlink != 0 {
// Defer links to the end
links = append(links, f)
continue
}
src, err := f.Open()
if err != nil {
return fmt.Errorf("failed to open bundle file %s: %w", f.Name, err)
}
destName := filepath.Join(dir, f.Name)
// Verify directory first
d := filepath.Dir(destName)
if _, err := os.Stat(d); err != nil {
err := os.MkdirAll(d, 0o755)
if err != nil {
return fmt.Errorf("failed to mkdir %s: %w", d, err)
}
}
destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return fmt.Errorf("failed to open output file %s: %w", destName, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, src); err != nil {
return fmt.Errorf("failed to open extract file %s: %w", destName, err)
}
}
for _, f := range links {
src, err := f.Open()
if err != nil {
return err
}
buf, err := io.ReadAll(src)
if err != nil {
return err
}
link := string(buf)
if link[0] == '/' {
return fmt.Errorf("bundle contains absolute symlink %s -> %s", f.Name, link)
}
if strings.HasPrefix(filepath.Join(filepath.Dir(f.Name), link), "..") {
return fmt.Errorf("bundle contains link outside of contents %s -> %s", f.Name, link)
}
if err = os.Symlink(link, filepath.Join(dir, f.Name)); err != nil {
return err
}
}
if err := verifyExtractedBundle(filepath.Join(dir, "Ollama.app")); err != nil {
return fmt.Errorf("signature verification failed: %s", err)
}
return nil
}
// If we detect an upgrade bundle, attempt to upgrade at startup
func DoUpgradeAtStartup() error {
bundle := getStagedUpdate()
if bundle == "" {
return fmt.Errorf("failed to lookup downloads")
}
if BundlePath == "" {
return fmt.Errorf("unable to upgrade at startup, app in development mode")
}
// [Re]verify before proceeding
if err := VerifyDownload(); err != nil {
_ = os.Remove(bundle)
slog.Warn("verification failure", "bundle", bundle, "error", err)
return nil
}
slog.Info("performing update at startup", "bundle", bundle)
return DoUpgrade(false)
}
func getStagedUpdate() string {
files, err := filepath.Glob(filepath.Join(UpdateStageDir, "*", "*.zip"))
if err != nil {
slog.Debug("failed to lookup downloads", "error", err)
return ""
}
if len(files) == 0 {
return ""
} else if len(files) > 1 {
// Shouldn't happen
slog.Warn("multiple update downloads found, using first one", "bundles", files)
}
return files[0]
}
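// IsUpdatePending reports whether a staged update bundle (UpdateStageDir/*/*.zip) is already on disk.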
func IsUpdatePending() bool {
return getStagedUpdate() != ""
}
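// chownWithAuthorization asks macOS Authorization Services (via cgo) to chown the installed app
// to the given user; it returns false if authorization is denied or the chown fails.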
func chownWithAuthorization(user string) bool {
u := C.CString(user)
defer C.free(unsafe.Pointer(u))
return (bool)(C.chownWithAuthorization(u))
}
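// verifyExtractedBundle runs the native code-signature check on the extracted bundle;
// a nil error means the signatures verified.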
func verifyExtractedBundle(path string) error {
p := C.CString(path)
defer C.free(unsafe.Pointer(p))
resp := C.verifyExtractedBundle(p)
if resp == nil {
return nil
}
return errors.New(C.GoString(resp))
}
//export goLogInfo
func goLogInfo(msg *C.cchar_t) {
slog.Info(C.GoString(msg))
}
//export goLogDebug
func goLogDebug(msg *C.cchar_t) {
slog.Debug(C.GoString(msg))
}
func alreadyMoved() string {
// Respect the user's intent if they chose "keep" vs. "replace" when dragging to Applications
installedAppPaths, err := filepath.Glob(filepath.Join(
strings.TrimSuffix(SystemWidePath, filepath.Ext(SystemWidePath))+"*"+filepath.Ext(SystemWidePath),
"Contents", "MacOS", "Ollama"))
if err != nil {
slog.Warn("failed to lookup installed app paths", "error", err)
return ""
}
exe, err := os.Executable()
if err != nil {
slog.Warn("failed to resolve executable", "error", err)
return ""
}
self, err := os.Stat(exe)
if err != nil {
slog.Warn("failed to stat running executable", "path", exe, "error", err)
return ""
}
selfSys := self.Sys().(*syscall.Stat_t)
for _, installedAppPath := range installedAppPaths {
app, err := os.Stat(installedAppPath)
if err != nil {
slog.Debug("failed to stat installed app path", "path", installedAppPath, "error", err)
continue
}
appSys := app.Sys().(*syscall.Stat_t)
if appSys.Ino == selfSys.Ino {
return filepath.Dir(filepath.Dir(filepath.Dir(installedAppPath)))
}
}
return ""
}
#import <Cocoa/Cocoa.h>
// TODO make these macros so we can extract line numbers from the native code
void appLogInfo(NSString *msg);
void appLogDebug(NSString *msg);
void goLogInfo(const char *msg);
void goLogDebug(const char *msg);
AuthorizationRef getAuthorization(NSString *authorizationPrompt,
NSString *right);
AuthorizationRef getAppInstallAuthorization();
const char* verifyExtractedBundle(char *path);
bool chownWithAuthorization(const char *user);
#import "updater_darwin.h"
#import <AppKit/AppKit.h>
#import <Cocoa/Cocoa.h>
#import <CoreServices/CoreServices.h>
#import <Security/Security.h>
#import <ServiceManagement/ServiceManagement.h>
void appLogInfo(NSString *msg) {
NSLog(@"%@", msg);
goLogInfo([msg UTF8String]);
}
void appLogDebug(NSString *msg) {
NSLog(@"%@", msg);
goLogDebug([msg UTF8String]);
}
NSString *SystemWidePath = @"/Applications/Ollama.app";
// TODO - how to detect if the user has admin access?
// Possible APIs to explore:
// - SFAuthorization
// - CSIdentityQueryCreateForCurrentUser + CSIdentityQueryCreateForName(NULL,
// CFSTR("admin"), kCSIdentityQueryStringEquals, kCSIdentityClassGroup,
// CSGetDefaultIdentityAuthority());
// Caller must call AuthorizationFree(authRef, kAuthorizationFlagDestroyRights)
// once finished
// TODO consider a struct response type to capture user cancel scenario from
// other error/failure scenarios
AuthorizationRef getAuthorization(NSString *authorizationPrompt,
NSString *right) {
appLogInfo([NSString stringWithFormat:@"XXX in getAuthorization"]);
AuthorizationRef authRef = NULL;
OSStatus err = AuthorizationCreate(NULL, kAuthorizationEmptyEnvironment,
kAuthorizationFlagDefaults, &authRef);
if (err != errAuthorizationSuccess) {
appLogInfo([NSString
stringWithFormat:
@"Failed to create authorization reference. Status = %d",
err]);
return NULL;
}
NSString *bundleIdentifier = [[NSBundle mainBundle] bundleIdentifier];
NSString *rightNameString =
[NSString stringWithFormat:@"%@.%@", bundleIdentifier, right];
const char *rightName = [rightNameString UTF8String];
appLogInfo([NSString stringWithFormat:@"XXX requesting right %@", rightNameString]);
OSStatus getRightResult = AuthorizationRightGet(rightName, NULL);
if (getRightResult == errAuthorizationDenied) {
// Create or update the right if it doesn't exist
if (AuthorizationRightSet(
authRef, rightName,
(__bridge CFTypeRef _Nonnull)(
@(kAuthorizationRuleAuthenticateAsAdmin)),
(__bridge CFStringRef _Nullable)(authorizationPrompt), NULL,
NULL) != errAuthorizationSuccess) {
appLogInfo([NSString
stringWithFormat:
@"Failed to set right for moving to /Applications"]);
AuthorizationFree(authRef, kAuthorizationFlagDestroyRights);
return NULL;
}
}
AuthorizationItem rightItem = {
.name = rightName, .valueLength = 0, .value = NULL, .flags = 0};
AuthorizationRights rights = {.count = 1, .items = &rightItem};
AuthorizationFlags flags =
(AuthorizationFlags)(kAuthorizationFlagExtendRights |
kAuthorizationFlagInteractionAllowed);
err = AuthorizationCopyRights(authRef, &rights, NULL, flags, NULL);
if (err != errAuthorizationSuccess) {
if (err == errAuthorizationCanceled) {
appLogInfo([NSString
stringWithFormat:@"User cancelled authorization. Status = %d",
err]);
// TODO bubble up user cancel/reject so we can keep track
} else {
appLogInfo([NSString
stringWithFormat:@"failed to grant authorization. Status = %d",
err]);
}
AuthorizationFree(authRef, kAuthorizationFlagDestroyRights);
return NULL;
}
return authRef;
}
AuthorizationRef getAppInstallAuthorization() {
return getAuthorization(
@"Ollama needs additional permission to move or update itself as a "
"system-wide Application",
@"systemApplication");
}
bool chownWithAuthorization(const char *user) {
AuthorizationRef authRef = getAppInstallAuthorization();
if (authRef == NULL) {
return NO;
}
const char *chownTool = "/usr/sbin/chown";
const char *chownArgs[] = {"-R", user, [SystemWidePath UTF8String], NULL};
FILE *pipe = NULL;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
OSStatus err = AuthorizationExecuteWithPrivileges(
authRef, chownTool, kAuthorizationFlagDefaults,
(char *const *)chownArgs, &pipe);
#pragma clang diagnostic pop
if (err != errAuthorizationSuccess) {
appLogInfo([NSString
stringWithFormat:@"Failed to update ownership of %@ Status = %d",
SystemWidePath, err]);
AuthorizationFree(authRef, kAuthorizationFlagDestroyRights);
return NO;
}
// wait for the command to finish
while (pipe && !feof(pipe)) {
fgetc(pipe);
}
if (pipe) {
fclose(pipe);
}
appLogDebug([NSString stringWithFormat:@"XXX finished chown"]);
AuthorizationFree(authRef, kAuthorizationFlagDestroyRights);
return true;
}
// nil if bundle is good, error string otherwise
const char *verifyExtractedBundle(char *path) {
NSString *p = [NSString stringWithFormat:@"%s", path];
appLogDebug([NSString stringWithFormat:@"verifyExtractedBundle: %@", p]);
SecStaticCodeRef staticCode = NULL;
OSStatus result = SecStaticCodeCreateWithPath(
CFURLCreateFromFileSystemRepresentation(
(__bridge CFAllocatorRef)(kCFAllocatorSystemDefault),
(const UInt8 *)path, strlen(path), kCFStringEncodingMacRoman),
kSecCSDefaultFlags, &staticCode);
if (result != noErr) {
NSString *failureReason =
CFBridgingRelease(SecCopyErrorMessageString(result, NULL));
appLogDebug([NSString
stringWithFormat:@"Failed to get static code for bundle: %@",
failureReason]);
if (staticCode != NULL)
CFRelease(staticCode);
return [[NSString
stringWithFormat:@"Failed to get static code for bundle: %@",
failureReason] UTF8String];
}
CFErrorRef validityError = NULL;
result = SecStaticCodeCheckValidityWithErrors(
staticCode, kSecCSCheckAllArchitectures, NULL, &validityError);
if (result != noErr) {
NSString *failureReason =
CFBridgingRelease(SecCopyErrorMessageString(result, NULL));
appLogDebug([NSString
stringWithFormat:@"Signatures did not verify on bundle: %@",
failureReason]);
// TODO - consider extracting additional details from validityError
if (validityError != NULL)
CFRelease(validityError);
return [[NSString
stringWithFormat:@"Signatures did not verify on bundle: %@",
failureReason] UTF8String];
}
appLogDebug([NSString stringWithFormat:@"bundle passed verification"]);
return NULL;
}